MNIST的平均图片

时间:2019-03-03 03:41:08

标签: python numpy mnist

使用MNIST数据集,我试图找到每个不同数字(0-9)的平均图像。以下代码为我提供了数据集中的每个不同图像,但我不确定如何获得每个类的平均值(0-9)

data = io.loadmat('mnist-original.mat')

x, y = data['data'].T, data['label'].T

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.5)


a=np.unique(y, return_index=True)
b = a[1]

plt.figure(figsize=(15,4.5))
for i in b:
    img=x[i][:].reshape(28,28)
    plt.imshow(img)
    plt.show()  

2 个答案:

答案 0 :(得分:2)

假设零的“平均”图像是标签= 0的所有训练数据的平均值。 例如:

avgImg = np.average(x_train[y_train==0],0)

我认为这就是您想要的:

import matplotlib.pyplot as plt
import numpy as np

plt.figure(figsize=(10,3))
for i in range(10):
    avgImg = np.average(x_train[y_train==i],0)
    plt.subplot(2, 5, i+1)
    plt.imshow(avgImg.reshape((16,16))) 
    plt.axis('off')

答案 1 :(得分:0)

numpy_indexed软件包(免责声明:我是它的作者)以向量化的方式提供了这种类型的功能:

import numpy_indexed as npi
digits, means = npi.group_by(y).mean(x)