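I'm confused by how my Keras model performs on a validation set built from the MNIST data.
For a quick experiment I'm only using the test data, which I downloaded as a CSV file from here: https://pjreddie.com/projects/mnist-in-csv/
My code is as follows: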
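import pandas as pd
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import to_categorical
mnist = pd.read_csv('mnist_test.csv', header = None)
mnist.head()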
0 1 2 3 4 5 6 7 8 9 ... 775 776 777 778 779 780 781 782 783 784
0 7 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
1 2 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
2 1 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
3 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
4 4 0 0 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
X = mnist.iloc[:, 1:].values
y = to_categorical(mnist.iloc[:, 0])
n_cols = X.shape[1]
# Create the model: model
model = Sequential()
# Add the first hidden layer
model.add(Dense(50, activation = 'relu', input_shape = (784,)))
# Add the second hidden layer
model.add(Dense(50, activation = 'relu'))
# Add the output layer
model.add(Dense(10, activation = 'softmax'))
# Compile the model
model.compile(optimizer = 'adam', loss = 'categorical_crossentropy', metrics = ['accuracy'])
# Fit the model
model.fit(X, y, validation_split = 0.3)
The output is:
Train on 7000 samples, validate on 3000 samples
Epoch 1/1
7000/7000 [==============================] - 1s 109us/step - loss: 10.9961 - acc: 0.3111 - val_loss: 10.2264 - val_acc: 0.3637
A validation accuracy of only 0.36? That is hard to believe for MNIST.
What went wrong?
Answer 0 (score: 0)
Try training for longer.
That is, pass epochs = 200 to the model.fit call.
Answer 1 (score: 0)
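Add normalization to your code (scale the pixel values from the range 0 to 255 down to 0 to 1). Then it should work as expected: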
# Convert to float32 and scale the pixels from 0..255 to 0..1
X = X.astype('float32')
y = y.astype('float32')
X /= 255
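As a rough illustration of why the input scale matters, here is a NumPy-only sketch (not Keras itself) of the question's 784-50-50-10 ReLU network at initialization. It assumes Keras-style defaults for Dense layers (Glorot-uniform kernels, zero biases) and uses random integer pixels as a stand-in for a real MNIST row. Because ReLU layers with zero bias satisfy f(c·x) = c·f(x) for c > 0, raw 0..255 pixels give logits exactly 255 times larger than pixels scaled to 0..1, which pushes the softmax toward near one-hot, overconfident predictions.

```python
import numpy as np

rng = np.random.default_rng(0)

def glorot_uniform(fan_in, fan_out):
    # Same limit as the Glorot-uniform initializer Keras uses for Dense kernels by default
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-limit, limit, size=(fan_in, fan_out))

# Hypothetical weights for a 784-50-50-10 MLP; biases are zero, as at Keras initialization
W1, W2, W3 = glorot_uniform(784, 50), glorot_uniform(50, 50), glorot_uniform(50, 10)

def logits(x):
    h = np.maximum(x @ W1, 0.0)  # first hidden layer: ReLU, zero bias
    h = np.maximum(h @ W2, 0.0)  # second hidden layer: ReLU, zero bias
    return h @ W3                # output logits (before softmax)

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

# Assumption: random pixels in 0..255 stand in for one MNIST image
x_raw = rng.integers(0, 256, size=(1, 784)).astype(np.float64)
x_scaled = x_raw / 255.0

z_raw, z_scaled = logits(x_raw), logits(x_scaled)
print("largest |logit|, raw:    ", np.abs(z_raw).max())
print("largest |logit|, scaled: ", np.abs(z_scaled).max())
print("max softmax prob, raw:   ", softmax(z_raw).max())
print("max softmax prob, scaled:", softmax(z_scaled).max())
```

The names W1, W2, W3, logits and softmax are hypothetical helpers for this illustration only. After training starts the biases are no longer zero, so the exact factor of 255 holds only at initialization.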
That is, the complete code:
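import pandas as pd
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import to_categorical
# Load the same CSV file as in the question
mnist_test = pd.read_csv('mnist_test.csv', header = None)
X = mnist_test.iloc[:, 1:].values
y = to_categorical(mnist_test.iloc[:, 0])
# Normalization: convert to float32 and scale the pixels from 0..255 to 0..1
X = X.astype('float32')
y = y.astype('float32')
X /= 255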
n_cols = X.shape[1]
# Create the model: model
model = Sequential()
# Add the first hidden layer
model.add(Dense(50, activation = 'relu', input_shape = (784,)))
# Add the second hidden layer
model.add(Dense(50, activation = 'relu'))
# Add the output layer
model.add(Dense(10, activation = 'softmax'))
# Compile the model
model.compile(optimizer = 'adam', loss = 'categorical_crossentropy', metrics = ['accuracy'])
# Fit the model
model.fit(X, y, validation_split = 0.3, epochs=10)
When I tested it, I got:
Train on 7000 samples, validate on 3000 samples
Epoch 1/10
7000/7000 [==============================] - 1s 179us/step - loss: 0.7758 - acc: 0.7734 - val_loss: 0.3359 - val_acc: 0.9073
Epoch 2/10
7000/7000 [==============================] - 1s 137us/step - loss: 0.3104 - acc: 0.9056 - val_loss: 0.2225 - val_acc: 0.9330
Epoch 3/10
7000/7000 [==============================] - 1s 133us/step - loss: 0.2291 - acc: 0.9339 - val_loss: 0.1958 - val_acc: 0.9390
Epoch 4/10
7000/7000 [==============================] - 1s 138us/step - loss: 0.1845 - acc: 0.9461 - val_loss: 0.1827 - val_acc: 0.9433
Epoch 5/10
7000/7000 [==============================] - 1s 138us/step - loss: 0.1509 - acc: 0.9571 - val_loss: 0.1678 - val_acc: 0.9483
Epoch 6/10
7000/7000 [==============================] - 1s 143us/step - loss: 0.1240 - acc: 0.9641 - val_loss: 0.1760 - val_acc: 0.9407
Epoch 7/10
7000/7000 [==============================] - 1s 136us/step - loss: 0.1012 - acc: 0.9710 - val_loss: 0.1801 - val_acc: 0.9453
Epoch 8/10
7000/7000 [==============================] - 1s 138us/step - loss: 0.0838 - acc: 0.9761 - val_loss: 0.1867 - val_acc: 0.9457
Epoch 9/10
7000/7000 [==============================] - 1s 132us/step - loss: 0.0697 - acc: 0.9780 - val_loss: 0.1820 - val_acc: 0.9450
Epoch 10/10
7000/7000 [==============================] - 1s 135us/step - loss: 0.0559 - acc: 0.9843 - val_loss: 0.1699 - val_acc: 0.9493
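The loss values in the question's unnormalized run are consistent with the overconfident-prediction picture. Assuming Keras clips predicted probabilities to [ε, 1 − ε] with ε = 1e-7 (the default value of keras.backend.epsilon()) before taking the log, a confidently wrong one-hot prediction costs −log(1e-7) ≈ 16.12 while a confidently correct one costs almost nothing, so the loss should be roughly (1 − accuracy) × 16.12. A small sketch checking that against the reported numbers:

```python
import math

# Assumption: probabilities are clipped to [EPS, 1 - EPS] before the log,
# with EPS = 1e-7 as the default Keras fuzz factor.
EPS = 1e-7

def saturated_ce_loss(accuracy, eps=EPS):
    """Mean categorical cross-entropy if every prediction is one-hot:
    correct predictions cost -log(1 - eps), wrong ones cost -log(eps)."""
    return accuracy * -math.log(1 - eps) + (1 - accuracy) * -math.log(eps)

# Accuracy and loss reported in the question's single epoch
for name, acc, reported in [("train", 0.3111, 10.9961), ("val", 0.3637, 10.2264)]:
    print(f"{name}: predicted {saturated_ce_loss(acc):.4f}, reported {reported}")
```

This predicts about 11.10 for training (reported 10.9961) and about 10.26 for validation (reported 10.2264). The training figure is averaged over the epoch while the model is still changing, so a small mismatch is expected; treat this as a consistency check rather than proof.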