绘制MNIST数字有问题

时间:2019-05-06 11:47:10

标签: python matplotlib computer-vision mnist

我正在尝试加载和可视化MNIST数字,但是我得到的像素移位了

import matplotlib.pyplot as plt
import numpy as np

mnist_data  = open('data/mnist/train-images-idx3-ubyte', 'rb')

image_size = 28
num_images = 4

buf = mnist_data.read(num_images * image_size * image_size)
data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
data = data.reshape(num_images, image_size, image_size)

_, axarr1 = plt.subplots(2,2)
axarr1[0, 0].imshow(data[0])
axarr1[0, 1].imshow(data[1])
axarr1[1, 0].imshow(data[2])
axarr1[1, 1].imshow(data[3])

MNIST

谁能告诉我为什么它发生的代码看起来不错,谢谢

1 个答案:

答案 0 :(得分:1)

您没有说在哪里获得MNIST数据,但是if it is formatted like the original data set,您似乎忘记了尝试访问数据之前先提取标头:

image_size = 28
num_images = 4

mnist_data = open('train-images-idx3-ubyte', 'rb')

mnist_data.seek(16) # skip over the first 16 bytes that correspond to the header
buf = mnist_data.read(num_images * image_size * image_size)
data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
data = data.reshape(num_images, image_size, image_size)

_, axarr1 = plt.subplots(2,2)
axarr1[0, 0].imshow(data[0])
axarr1[0, 1].imshow(data[1])
axarr1[1, 0].imshow(data[2])
axarr1[1, 1].imshow(data[3])

enter image description here