I am trying to take a subset of the MNIST dataset in PyTorch (an equal number of samples from each class), but I get this error:
import numpy as np
from numpy.random import RandomState
prng = RandomState(42)
random_permute = prng.permutation(np.arange(0, 6000))[0:3000]
indx = np.concatenate([np.where(np.array(mnist_data.targets) == classe)[0][random_permute] for classe in range(0,10)])
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-178-038015f76b77> in <module>
----> 1 indx = np.concatenate([np.where(np.array(mnist_data.targets) == classe)[0][random_permute] for classe in range(0,10)])
<ipython-input-178-038015f76b77> in <listcomp>(.0)
----> 1 indx = np.concatenate([np.where(np.array(mnist_data.targets) == classe)[0][random_permute] for classe in range(0,10)])
IndexError: index 5992 is out of bounds for axis 0 with size 5923
Answer 0 (score: 0)
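The classes in the MNIST dataset are not evenly distributed. You get this error because class 0 in MNIST contains only 5923 samples, while random_permute holds positions drawn from 0 to 5999, so a position such as 5992 is out of bounds for that class. You can count the samples per class like this: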
nums = [0]*10
for i in range(60000):
    nums[int(mnist_data.targets[i])] += 1
print(nums)
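This prints [5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]. The smallest class (digit 5) has 5421 samples, so any position used to index into a class must stay below that class's own count.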
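One possible way to draw the same number of samples from every class is to permute each class's own indices instead of reusing a single permutation of 0..5999. The following is only a sketch: the helper name balanced_subset_indices, its parameters, and the toy label array are assumptions made for illustration and are not part of the original question.

```python
import numpy as np

def balanced_subset_indices(targets, n_per_class, seed=42):
    """Return indices that select n_per_class random samples from each class.

    Hypothetical helper for illustration, not a PyTorch or torchvision API.
    """
    targets = np.asarray(targets)
    prng = np.random.RandomState(seed)
    classes = np.unique(targets)

    # The request cannot exceed the size of the smallest class.
    smallest = min(int(np.sum(targets == c)) for c in classes)
    if n_per_class > smallest:
        raise ValueError(
            f"n_per_class={n_per_class} exceeds the smallest class size {smallest}"
        )

    picked = []
    for c in classes:
        class_idx = np.where(targets == c)[0]
        # Shuffle this class's own indices, then keep the first n_per_class.
        picked.append(prng.permutation(class_idx)[:n_per_class])
    return np.concatenate(picked)

# Toy stand-in for mnist_data.targets with deliberately uneven class sizes.
toy_targets = np.repeat(np.arange(10), [12, 15, 11, 14, 10, 9, 13, 16, 10, 11])
indx = balanced_subset_indices(toy_targets, n_per_class=5)
print(np.bincount(toy_targets[indx]))
```

With the real dataset, you could pass np.array(mnist_data.targets) and n_per_class=3000, then wrap the result with torch.utils.data.Subset(mnist_data, indx) to get the reduced dataset.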