绘制MNIST样本

时间:2018-12-18 16:42:30

标签: python matplotlib data-science mnist

我正在尝试从MNIST数据集中绘制10个样本。每个数字之一。这是代码:

import sklearn
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets

mnist = datasets.fetch_mldata('MNIST original')
y = mnist.target
X = mnist.data

for i in range(10):
    im_idx = np.argwhere(y == i)[0]
    print(im_idx)
    plottable_image = np.reshape(X[im_idx], (28, 28))
    plt.imshow(plottable_image, cmap='gray_r')
    plt.subplot(2, 5, i + 1)

plt.plot()

由于某些原因,图中的数字被跳过了。

为什么?

2 个答案:

答案 0 :(得分:3)

好,知道了。问题是您在绘制imshow之后定义了子图。因此,您的第一个子图被第二个图覆盖。为了使您的代码正常工作,只需按以下步骤交换两个命令的顺序即可。另外,我不明白您为什么最后使用plt.plot()

plt.subplot(2, 5, i + 1) # <-- You have put this command after imshow 
plt.imshow(plottable_image, cmap='gray_r')

以下是您所知的另一种替代方法:

fig = plt.figure()

for i in range(10):
    im_idx = np.argwhere(y == i)[0]
    plottable_image = np.reshape(X[im_idx], (28, 28))
    ax = fig.add_subplot(2, 5, i+1)
    ax.imshow(plottable_image, cmap='gray_r')

您还可以使用以下方法进一步缩短Scott的代码(如下所述):

fig, ax = plt.subplots(2,5)
for i, ax in enumerate(ax.flatten()):
    im_idx = np.argwhere(y == i)[0]
    plottable_image = np.reshape(X[im_idx], (28, 28))
    ax.imshow(plottable_image, cmap='gray_r')

enter image description here

答案 1 :(得分:2)

尝试一下:

import sklearn
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets

mnist = datasets.fetch_mldata('MNIST original')
y = mnist.target
X = mnist.data

fig, ax = plt.subplots(2,5)
ax = ax.flatten()
for i in range(10):
    im_idx = np.argwhere(y == i)[0]
    print(im_idx)
    plottable_image = np.reshape(X[im_idx], (28, 28))
    ax[i].imshow(plottable_image, cmap='gray_r')

输出:

enter image description here