在Python和Matplotlib中,很容易将绘图显示为弹出窗口或将绘图保存为PNG文件。我怎样才能将绘图保存为RGB格式的numpy数组?
答案 0 :(得分:53)
当您需要与保存的绘图进行像素到像素的比较时,这是单元测试等方便的技巧。
一种方法是使用fig.canvas.tostring_rgb
,然后numpy.fromstring
使用适当的dtype。还有其他方法,但这是我倾向于使用的方式。
E.g。
import matplotlib.pyplot as plt
import numpy as np
# Make a random plot...
fig = plt.figure()
fig.add_subplot(111)
# If we haven't already shown or saved the plot, then we need to
# draw the figure first...
fig.canvas.draw()
# Now we can save it to a numpy array.
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
答案 1 :(得分:15)
@JUN_NETWORKS的答案有一个更简单的选项。除了将图形保存为png
之外,还可以使用其他格式,例如raw
或rgba
,并跳过cv2
解码步骤。
换句话说,实际的绘图到numpy转换可以归结为:
io_buf = io.BytesIO()
fig.savefig(io_buf, format='raw', dpi=DPI)
io_buf.seek(0)
img_arr = np.reshape(np.frombuffer(io_buf.getvalue(), dtype=np.uint8),
newshape=(int(fig.bbox.bounds[3]), int(fig.bbox.bounds[2]), -1))
io_buf.close()
希望,这会有所帮助。
答案 2 :(得分:3)
如果有人想要一个即插即用的解决方案,而无需修改任何先前的代码(获得对 pyplot 图和所有内容的引用),以下对我有用。只需在所有 pyplot
语句之后添加它,即在 pyplot.show()
canvas = pyplot.gca().figure.canvas
canvas.draw()
data = numpy.frombuffer(canvas.tostring_rgb(), dtype=numpy.uint8)
image = data.reshape(canvas.get_width_height()[::-1] + (3,))
答案 3 :(得分:2)
有人提出了这样的方法
np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
当然,此代码有效。但是,输出的numpy数组图像分辨率太低。
我的提案代码是这个。
import io
import cv2
import numpy as np
import matplotlib.pyplot as plt
# plot sin wave
fig = plt.figure()
ax = fig.add_subplot(111)
x = np.linspace(-np.pi, np.pi)
ax.set_xlim(-np.pi, np.pi)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.plot(x, np.sin(x), label="sin")
ax.legend()
ax.set_title("sin(x)")
# define a function which returns an image as numpy array from figure
def get_img_from_fig(fig, dpi=180):
buf = io.BytesIO()
fig.savefig(buf, format="png", dpi=180)
buf.seek(0)
img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
buf.close()
img = cv2.imdecode(img_arr, 1)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img
# you can get a high-resolution image as numpy array!!
plot_img_np = get_img_from_fig(fig)
此代码效果很好。
如果在dpi参数上设置较大的数字,则可以将高分辨率图像作为numpy数组获得。
答案 4 :(得分:1)
是时候对您的解决方案进行基准测试了。
import io
import matplotlib
matplotlib.use('agg') # turn off interactive backend
import matplotlib.pyplot as plt
import numpy as np
fig, ax = plt.subplots()
ax.plot(range(10))
def plot1():
fig.canvas.draw()
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
w, h = fig.canvas.get_width_height()
im = data.reshape((int(h), int(w), -1))
def plot2():
with io.BytesIO() as buff:
fig.savefig(buff, format='png')
buff.seek(0)
im = plt.imread(buff)
def plot3():
with io.BytesIO() as buff:
fig.savefig(buff, format='raw')
buff.seek(0)
data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
w, h = fig.canvas.get_width_height()
im = data.reshape((int(h), int(w), -1))
>>> %timeit plot1()
34 ms ± 4.16 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit plot2()
50.2 ms ± 234 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit plot3()
16.4 ms ± 36 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
在这种情况下,IO 原始缓冲区是将 matplotlib 图转换为 numpy 数组的最快速度。
补充说明:
如果您无法访问图形,您可以随时从轴中提取它:
fig = ax.figure
如果您需要 channel x height x width
格式的数组,请执行
im = im.transpose((2, 0, 1))
。