循环更新图像中的Pyplot子图

时间:2019-06-09 11:31:43

标签: python numpy matplotlib

我认为我做的事情确实很愚蠢,但是我不太明白。我想创建一个类来显示一组图像作为子图;显示应该从循环内手动更新。这是我为尝试执行此操作而创建的课程:

import matplotlib.pyplot as plt
import numpy as np

class tensor_plot:
    def __init__(self, tensor_shape, nrows=1):
        self.img_height, self.img_width, self.num_imgs = tensor_shape
        self.nrows = nrows
        self.ncols = self.num_imgs // nrows
        assert(self.ncols*self.nrows == self.num_imgs)
        self.fig, self.a = plt.subplots(self.nrows, self.ncols, sharex='col', sharey='row')
        for (row, col) in zip(range(self.nrows), range(self.ncols)):
            self.a[row, col] = plt.imshow(np.zeros([self.img_height, self.img_width]))

    def update(self, tensor):
        n=0
        for row in range(self.nrows):
            for col in range(self.ncols):
                self.a[row,col].set_data(tensor[:,:,n].squeeze())
                n += 1
        plt.show()

当我尝试传递张量进行更新时,它说没有set_data属性。但是使用dir就有这样的属性。

In [322]: tp = tensor_plot(l10.shape, 4)

In [323]: tp.update(l10)
AttributeError: 'AxesSubplot' object has no attribute 'set_data'


In [324]: dir(tp.a[0,0])
Out[324]: 
['_A',
...
 'set_data',
...
 'update_from',
 'write_png',
 'zorder']

如果我在循环中添加行print(dir(self.a[row,col])),则确实没有set_data!相同的评论适用于imshow

有什么想法吗?

1 个答案:

答案 0 :(得分:0)

非常感谢@ImportanceOfBeingEarnest,这是对我有用的最终代码(以防对他人有用)。

class tensor_plot:
    def __init__(self, tensor_shape, nrows=1):
        self.img_height, self.img_width, self.num_imgs = tensor_shape
        self.nrows = nrows
        self.ncols = self.num_imgs // nrows
        assert(self.ncols*self.nrows == self.num_imgs)
        self.fig, self.a = plt.subplots(self.nrows, self.ncols, sharex='col', sharey='row')

        self.imgs = np.array( [   [ self.a[row, col].imshow(np.zeros([self.img_height, self.img_width])) for col in range(self.ncols)    ] for row in range(self.nrows)])
        plt.pause(0.1)


    def update(self, tensor):
        n=0
        for row in range(self.nrows):
            for col in range(self.ncols):
                self.imgs[row,col].set_data(tensor[:,:,n].squeeze())
                self.imgs[row,col].set_clim(vmin=0, vmax=255)
                n += 1
        self.fig.canvas.draw_idle()
        plt.pause(0.01)
        plt.draw_all()