使用Python显示来自CIFAR10数据集的10个类别中的每个类别的前三张图像

时间:2018-08-28 19:10:49

标签: python keras

我已经完成了这段代码,但是它随机显示图像。如何显示10个类别中每一个的前三个图像,然后将其绘制为3 x 10图像?

from keras.datasets import cifar10
import matplotlib.pyplot as plt

nb_classes = 10
class_name = {
 0: 'airplane',
 1: 'automobile',
 2: 'bird',
 3: 'cat',
 4: 'deer',
 5: 'dog',
 6: 'frog',
 7: 'horse',
 8: 'ship',
 9: 'truck'
}

(X_train, y_train), (X_test, y_test) = cifar10.load_data()
y_train = y_train.reshape(y_train.shape[0]) 
y_test = y_test.reshape(y_test.shape[0])

X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255

def draw_img(i):
  im = X_train[i]
  c = y_train[i]
  plt.imshow(im)
  plt.title("Class %d (%s)" % (c, class_name[c]))
  plt.axis('on')

 def draw_sample(X, y, n, rows=4, cols=4, imfile=None, fontsize=10):
   for i in range(0, rows*cols):
    plt.subplot(rows, cols, i+1)
    im = X[n+i].reshape(32,32,3)
    plt.imshow(im, cmap='gnuplot2')
    plt.title("{}".format(class_name[y[n+i]]), fontsize=fontsize)
    plt.axis('off')
    plt.subplots_adjust(wspace=0.8, hspace=0.01)
 if imfile:
   plt.savefig(imfile)
draw_img(0)
draw_sample(X_train, y_train,9, 3, 10)

输出: Result

0 个答案:

没有答案