用于子图的pyplot绘制方法

时间:2015-04-19 23:38:53

标签: python matplotlib iteration subplot

我有10个数字0-9的图像,每个图像包含28x28像素,形状X的数组(28**2, 10)

我正在使用循环中的新像素更新X,我想在每次迭代时更新我的​​绘图。

目前,我的代码将创建100个单独的数字。

def plot_output(X):
    """grayscale images of the digits 0-9
    in 28x28 pixels in pyplot

    Input, X is of shape (28^2, 10)
    """
    n = X.shape[1] # number of digits
    pixels = (28,28) # pixel shape
    fig, ax = plt.subplots(1,n)

    # cycle through digits from 0-9
    # X input array is reshaped for each 10 digits
    # to a (28,28) vector to plot
    for i in range(n):
        wi=X[:,i] # w = weights for digit

        wi=wi.reshape(*pixels)
        ax[i].imshow(wi,cmap=plt.cm.gist_gray,
            interpolation='gaussian', aspect='equal')
        ax[i].axis('off')
        ax[i].set_title('{0:0d}'.format(i))

    plt.tick_params(axis='x', which='both', bottom='off',
        top='off', labelbottom='off')

    plt.show()

for i in range(100):
    X = init_pix() # anything that generates a (728, 10) array
    plot_output(X)

我尝试过使用plt.draw()pt.canvas.draw(),但我似乎无法正确实现它。我也试过了plt.clf()这对我也不起作用。

我能够使用this post使用直线和一个轴来做到这一点,但我不能让它在子图上工作。

2 个答案:

答案 0 :(得分:2)

通过使用plt.ion(),你可以制作plt.show()命令,通常是阻止,而不是阻止。

然后您可以使用imshow更新轴,并且它们会在计算出来时出现在图中。

例如:

import numpy as np
import matplotlib.pyplot as plt

n=10

X = np.random.rand(28**2,n)

fig, ax = plt.subplots(1,n)

plt.ion()
plt.show()

for i in range(n):
    wi = X[:,1].reshape(28,28)
    ax[i].imshow(wi)

    #fig.canvas.draw()  # May be necessary, wasn't for me.

plt.ioff()  # Make sure to make plt.show() blocking again, otherwise it'll run
plt.show()  #   right through this and immediately close the window (if program exits)

你现在会得到丑陋的巨大的空白轴,直到你的轴被定义,但这应该让你开始。

答案 1 :(得分:0)

我通过创建一个绘图类并在每个轴上使用.cla()然后使用imshow()

重新定义每个轴来找到解决方案
class plot_output(object):

    def __init__(self, X):
        """grayscale images of the digits 1-9
        """
        self.X = X
        self.n = X.shape[1] # number of digits
        self.pixels = (25,25) # pixel shape
        self.fig, self.ax = plt.subplots(1,self.n)
        plt.ion()

        # cycle through digits from 0-9
        # X input vector is reshaped for each 10 digits
        # to a (28,28) vector to plot
        self.img_obj_ar = []

        for i in range(self.n):
            wi=X[:,i] # w = weights for digit

            wi=wi.reshape(*self.pixels)
            self.ax[i].imshow(wi,cmap=plt.cm.gist_gray,
                interpolation='gaussian', aspect='equal')
            self.ax[i].axis('off')
            self.ax[i].set_title('{0:0d}'.format(i))

        plt.tick_params(\
            axis='x',          # changes apply to the x-axis
            which='both',      # both major and minor ticks are affected
            bottom='off',      # ticks along the bottom edge are off
            top='off',         # ticks along the top edge are off
            labelbottom='off')

        plt.tick_params(\
            axis='y',          # changes apply to the y-axis
            which='both',      # both major and minor ticks are affected
            left='off', 
            right='off',    # ticks along the top edge are off
            labelleft='off')

        plt.show()

    def update(self, X):

        # cycle through digits from 0-9
        # X input vector is reshaped for each 10 digits
        # to a (28,28) vector to plot
        for i in range(self.n):
            self.ax[i].cla()
            wi=X[:,i] # w = weights for digit

            wi=wi.reshape(*self.pixels)
            self.ax[i].imshow(wi,cmap=plt.cm.gist_gray,
                            interpolation='gaussian', aspect='equal')
            self.ax[i].axis('off')
            self.ax[i].set_title('{0:0d}'.format(i))

        plt.draw()