我正在尝试为 CIFAR-10 构建一个神经网络(NN)分类器,但一直遇到这个错误:IndexError: index 1 is out of bounds for axis 3 with size 1(索引 1 超出了大小为 1 的轴 3 的范围)。
我没有太多编程背景,所以这些对我来说都很陌生。
# Map CIFAR-10 class indices (0-9) to human-readable class names.
label_dict = dict(enumerate((
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
)))

# Class name for every sample: the position of the maximum in each label row
# selects the entry of label_dict.
# NOTE(review): assumes each row of y_train / y_test is a one-hot (or score)
# vector - TODO confirm. For a length-1 row, argmax is always 0.
labels_train = np.asarray(
    [label_dict[row.argmax(0)] for row in y_train]
)
labels_test = np.asarray(
    [label_dict[row.argmax(0)] for row in y_test]
)
def plot_images(x, labels=None, nrow=2, ncol=4, im_type='image'):
    """Plot randomly chosen single-channel slices of ``x`` as heatmaps.

    Every subplot shows ``x[j, :, :, k]`` for a random index ``j`` along
    axis 0 and a random index ``k`` along axis 3, drawn with the reversed
    gray colormap and without ticks or a colorbar.

    Args:
        x: 4-D array. When ``im_type == 'filter'`` its last axis is moved to
            the front (``x.transpose(3, 0, 1, 2)``) before plotting.
        labels: Optional sequence indexed like axis 0 of the (possibly
            transposed) array; ``labels[j]`` is used as the subplot title.
        nrow: Number of subplot rows.
        ncol: Number of subplot columns.
        im_type: ``'filter'`` triggers the transpose described above; any
            other value plots ``x`` as given.

    Raises:
        ValueError: If ``x`` is not 4-D.
    """
    # Check the rank before creating a figure. The slice below addresses
    # axis 3 explicitly, so any other rank either fails or picks the wrong
    # axis, e.g. "index 1 is out of bounds for axis 3 with size 1" when the
    # channel count was read from the last axis of a 5-D array.
    if len(x.shape) != 4:
        raise ValueError(
            'plot_images expects a 4-D array, got shape {}'.format(
                tuple(x.shape)
            )
        )

    if im_type == 'filter':
        x = x.transpose(3, 0, 1, 2)  # Transpose to image type

    # squeeze=False always yields a 2-D array of Axes, so ravel() also works
    # for a 1x1 grid (where subplots would otherwise return a bare Axes).
    _, axes = plt.subplots(
        nrow, ncol, figsize=(5 * ncol, 4 * nrow), squeeze=False
    )

    num_samples = x.shape[0]
    # Read the count from the same axis that is indexed below.
    num_channels = x.shape[3]

    for a in axes.ravel():
        j = np.random.choice(num_samples)
        k = np.random.choice(num_channels)
        sns.heatmap(x[j, :, :, k], ax=a, cbar=False, cmap='gray_r')
        if labels is not None:
            a.set_title(labels[j])
        a.set_xticks([])
        a.set_yticks([])
# Plot random slices of x_train, each titled with its entry in labels_train.
plot_images(x_train, labels=labels_train)
以下是我一直遇到的错误:
IndexError Traceback (most recent call last)
<ipython-input-11-f48331a29aac> in <module>
14 a.set_yticks([])
15
---> 16 plot_images(x_train, labels=labels_train)
<ipython-input-11-f48331a29aac> in plot_images(x, labels, nrow, ncol, im_type)
8 j = np.random.choice(num_samples)
9 k = np.random.choice(num_channels)
---> 10 sns.heatmap(x[j, :, :, k], ax=a, cbar=False, cmap='gray_r')
11 if labels is not None:
12 a.set_title(labels[j])
IndexError: index 1 is out of bounds for axis 3 with size 1