Error in a simple deep learning example with MNIST data

Time: 2019-07-12 11:44:36

Tags: python keras deep-learning

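I'm testing simple deep learning code on the MNIST data, but I get an error and I'm not sure why. The following code is taken from François Chollet's book Deep Learning with Python: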

from keras.datasets import mnist
from keras import models
from keras import layers

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype('float32') / 255

test_images = test_images.reshape((10000, 28*28))
test_images = test_images.astype('float32') / 255 

network = models.Sequential()
network.add(layers.Dense(512, activation='relu', input_shape=(28*28,)))
network.add(layers.Dense(10, activation = 'softmax'))

network.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])

network.fit(train_images, train_labels, epochs=5, batch_size=128)

I get the following error:

ValueError                                Traceback (most recent call last)
<ipython-input-9-fb9fd206ece1> in <module>
     18 network.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
     19 
---> 20 network.fit(train_images, train_labels, epochs=5, batch_size=128)

~/.local/lib/python3.7/site-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
    950             sample_weight=sample_weight,
    951             class_weight=class_weight,
--> 952             batch_size=batch_size)
    953         # Prepare validation data.
    954         do_validation = False

~/.local/lib/python3.7/site-packages/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, check_array_lengths, batch_size)
    787                 feed_output_shapes,
    788                 check_batch_axis=False,  # Don't enforce the batch size.
--> 789                 exception_prefix='target')
    790 
    791             # Generate sample-wise weight values given the `sample_weight` and

~/.local/lib/python3.7/site-packages/keras/engine/training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    136                             ': expected ' + names[i] + ' to have shape ' +
    137                             str(shape) + ' but got array with shape ' +
--> 138                             str(data_shape))
    139     return data
    140 

ValueError: Error when checking target: expected dense_9 to have shape (10,) but got array with shape (1,)

1 Answer:

Answer 0 (score: 3)

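Your labels array has shape (n, 1), while your model expects an array of shape (n, 10). mnist.load_data() returns the labels as a 1-D vector of integer class indices (0 to 9), one per image, which Keras treats as one column per sample; that is the (1,) in the error message. Your last layer, however, is Dense(10, activation='softmax'), and the categorical_crossentropy loss expects one-hot encoded targets with 10 columns. You need to convert both label arrays to categorical (one-hot) arrays, for example with keras.utils.to_categorical, like this: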

from keras.utils import to_categorical
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
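For reference, to_categorical turns a 1-D vector of integer class indices of shape (n,) into an (n, num_classes) one-hot matrix. The sketch below reproduces that transformation with plain NumPy so the shapes involved in the error are visible. The helper name one_hot and the sample labels are hypothetical, for illustration only; this is not the Keras implementation.

```python
import numpy as np


def one_hot(labels, num_classes=10):
    """Return a (len(labels), num_classes) float32 one-hot matrix.

    Hypothetical helper: an illustrative NumPy equivalent of
    keras.utils.to_categorical for 1-D integer labels, not Keras's own code.
    """
    labels = np.asarray(labels, dtype='int64').ravel()
    # Row i of the identity matrix is the one-hot vector for class i,
    # so fancy-indexing with the labels picks one row per sample.
    return np.eye(num_classes, dtype='float32')[labels]


# A few sample integer labels (illustrative values, like mnist's train_labels).
train_labels = np.array([5, 0, 4, 1, 9])
print(train_labels.shape)   # (5,): one integer per sample, the "(1,)" target

encoded = one_hot(train_labels)
print(encoded.shape)        # (5, 10): matches the Dense(10, softmax) output
print(encoded[0])           # 1.0 at index 5, zeros elsewhere

# argmax along the class axis recovers the original integer labels.
assert (encoded.argmax(axis=1) == train_labels).all()
```

With the real data, the same conversion turns train_labels from shape (60000,) into (60000, 10), which is what network.fit expects for this model.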