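I'm testing some simple deep learning code on the MNIST data, but I get an error and I'm not sure why. The following code is taken from Francois Chollet's book "Deep Learning with Python":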
from keras.datasets import mnist
from keras import models
from keras import layers
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype('float32') / 255
test_images = test_images.reshape((10000, 28*28))
test_images = test_images.astype('float32') / 255
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', input_shape=(28*28,)))
network.add(layers.Dense(10, activation='softmax'))
network.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
network.fit(train_images, train_labels, epochs=5, batch_size=128)
I get the following error:
ValueError Traceback (most recent call last)
<ipython-input-9-fb9fd206ece1> in <module>
18 network.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
19
---> 20 network.fit(train_images, train_labels, epochs=5, batch_size=128)
~/.local/lib/python3.7/site-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
950 sample_weight=sample_weight,
951 class_weight=class_weight,
--> 952 batch_size=batch_size)
953 # Prepare validation data.
954 do_validation = False
~/.local/lib/python3.7/site-packages/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, check_array_lengths, batch_size)
787 feed_output_shapes,
788 check_batch_axis=False, # Don't enforce the batch size.
--> 789 exception_prefix='target')
790
791 # Generate sample-wise weight values given the `sample_weight` and
~/.local/lib/python3.7/site-packages/keras/engine/training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
136 ': expected ' + names[i] + ' to have shape ' +
137 str(shape) + ' but got array with shape ' +
--> 138 str(data_shape))
139 return data
140
ValueError: Error when checking target: expected dense_9 to have shape (10,) but got array with shape (1,)
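To see where the `(1,)` comes from, here is a minimal NumPy sketch that roughly mimics the target check. It assumes, as Keras 2.x did, that a 1-D target array is expanded to 2-D, so `(n,)` becomes `(n, 1)`, before its per-sample shape is compared with the output layer's shape. `check_target` is a hypothetical helper for illustration, not part of Keras.

```python
import numpy as np


def check_target(labels, expected_shape):
    """Hypothetical helper: roughly mimics Keras 2.x's target shape check."""
    # Assumption: like Keras 2.x, expand a 1-D target array to 2-D: (n,) -> (n, 1).
    if labels.ndim == 1:
        labels = np.expand_dims(labels, 1)
    # The batch axis is not checked, only the per-sample shape.
    data_shape = labels.shape[1:]
    if data_shape != expected_shape:
        raise ValueError(
            "expected target to have shape " + str(expected_shape)
            + " but got array with shape " + str(data_shape)
        )
    return labels


# Stand-in for the MNIST training labels: one integer class index (0-9) per sample.
train_labels = np.arange(60000) % 10
print(train_labels.shape)  # (60000,)

try:
    # The Dense(10, activation='softmax') output layer expects (10,) per sample.
    check_target(train_labels, (10,))
except ValueError as err:
    print(err)  # expected target to have shape (10,) but got array with shape (1,)
```

The integer labels pass the batch-size check but fail the per-sample check, because each sample carries a single number rather than a 10-element vector.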
Answer (score: 3)
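Your label arrays hold one integer class index per sample (shape (60000,), which Keras treats as (60000, 1), hence the (1,) in the error), while your model, with its Dense(10, activation='softmax') output and the categorical_crossentropy loss, expects one-hot targets of shape (60000, 10). You need to convert the label arrays to categorical (one-hot) arrays, for example with keras.utils.to_categorical, like this: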
from keras.utils import to_categorical
# Turn integer labels (0-9) into one-hot vectors of length 10
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
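For integer labels in the range 0 to 9, `to_categorical` produces one-hot rows: label `k` becomes a length-10 vector with a 1 at index `k` and 0 elsewhere. The sketch below shows an equivalent NumPy transformation that indexes into an identity matrix; `one_hot` is a hypothetical helper for illustration, not Keras's own implementation.

```python
import numpy as np


def one_hot(labels, num_classes=None):
    """Hypothetical helper: one-hot encode integer class indices with NumPy."""
    labels = np.asarray(labels, dtype=int)
    if num_classes is None:
        # Like to_categorical's default: infer the class count from the largest label.
        num_classes = int(labels.max()) + 1
    # Row k of the identity matrix is the one-hot vector for class k.
    return np.eye(num_classes, dtype='float32')[labels]


labels = np.array([5, 0, 4, 1])  # sample integer labels
encoded = one_hot(labels, num_classes=10)
print(encoded.shape)  # (4, 10)
print(encoded[0])     # [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
# argmax along each row recovers the original integer labels.
print(encoded.argmax(axis=1))  # [5 0 4 1]
```

With targets of shape (60000, 10), the per-sample shape matches the model's output, and categorical_crossentropy can compare each one-hot row with the 10-way softmax prediction.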