Why is my CNN model overfitting?

Time: 2017-10-10 14:10:52

Tags: machine-learning computer-vision keras

I am training a CNN model with Keras for an OCR task, as shown below. It has 46 classes and 78,000 examples in total, with the same number of examples per class. As you can see, the validation error keeps rising, and a quick search suggests the model is overfitting. So I added dropout layers and removed some layers. Testing it, that limited the overfitting slightly, but after a few epochs the model still overfits. I have tested several variations, but the trend is similar: validation accuracy seems to stall at ~0.02 and validation loss keeps getting worse. Any help would be appreciated.

Code:

from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout

model = Sequential()
model.add(Conv2D(4, (3, 3), activation='relu', input_shape=(28, 28, 3)))
model.add(Conv2D(8, (3, 3), activation='relu', strides=(1, 1)))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(144, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(46, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='Adadelta', metrics=['accuracy'])

Training results:

Epoch 1/1000
62560/62560 [==============================] - 6s - loss: 15.5154 - acc: 0.0305 - val_loss: 16.1079 - val_acc: 6.3939e-05
Epoch 2/1000
62560/62560 [==============================] - 2s - loss: 15.1607 - acc: 0.0501 - val_loss: 15.7795 - val_acc: 0.0204
Epoch 3/1000
62560/62560 [==============================] - 2s - loss: 14.6713 - acc: 0.0772 - val_loss: 15.7587 - val_acc: 0.0212
Epoch 4/1000
62560/62560 [==============================] - 2s - loss: 14.2172 - acc: 0.0994 - val_loss: 15.7280 - val_acc: 0.0217
Epoch 5/1000
62560/62560 [==============================] - 2s - loss: 13.7012 - acc: 0.1225 - val_loss: 15.7533 - val_acc: 0.0205
Epoch 6/1000
62560/62560 [==============================] - 2s - loss: 13.4010 - acc: 0.1306 - val_loss: 15.7496 - val_acc: 0.0208
Epoch 7/1000
62560/62560 [==============================] - 3s - loss: 11.4178 - acc: 0.1267 - val_loss: 5.9195 - val_acc: 0.0196
Epoch 8/1000
62560/62560 [==============================] - 2s - loss: 4.0395 - acc: 0.0515 - val_loss: 4.6731 - val_acc: 0.0198
Epoch 9/1000
62560/62560 [==============================] - 2s - loss: 3.8222 - acc: 0.0458 - val_loss: 4.4824 - val_acc: 0.0192
Epoch 10/1000
62560/62560 [==============================] - 2s - loss: 3.7936 - acc: 0.0485 - val_loss: 4.6039 - val_acc: 0.0199
Epoch 11/1000
62560/62560 [==============================] - 2s - loss: 3.7754 - acc: 0.0495 - val_loss: 4.5338 - val_acc: 0.0198
Epoch 12/1000
62560/62560 [==============================] - 2s - loss: 3.7656 - acc: 0.0513 - val_loss: 4.6942 - val_acc: 0.0203
Epoch 13/1000
62560/62560 [==============================] - 2s - loss: 3.7504 - acc: 0.0535 - val_loss: 4.6317 - val_acc: 0.0202
Epoch 14/1000
62560/62560 [==============================] - 2s - loss: 3.7448 - acc: 0.0530 - val_loss: 4.7129 - val_acc: 0.0200
Epoch 15/1000
62560/62560 [==============================] - 2s - loss: 3.7377 - acc: 0.0562 - val_loss: 4.6958 - val_acc: 0.0205
Epoch 16/1000
62560/62560 [==============================] - 2s - loss: 3.7269 - acc: 0.0600 - val_loss: 4.9782 - val_acc: 0.0207
Epoch 17/1000
62560/62560 [==============================] - 2s - loss: 3.7193 - acc: 0.0606 - val_loss: 4.7774 - val_acc: 0.0206
Epoch 18/1000
62560/62560 [==============================] - 2s - loss: 3.7079 - acc: 0.0630 - val_loss: 4.8615 - val_acc: 0.0205
Epoch 19/1000
62560/62560 [==============================] - 2s - loss: 3.7000 - acc: 0.0658 - val_loss: 4.8694 - val_acc: 0.0205
Epoch 20/1000
62560/62560 [==============================] - 2s - loss: 3.6911 - acc: 0.0684 - val_loss: 5.0777 - val_acc: 0.0205
Epoch 21/1000
62560/62560 [==============================] - 2s - loss: 3.6821 - acc: 0.0713 - val_loss: 4.9727 - val_acc: 0.0204
Epoch 22/1000
62560/62560 [==============================] - 2s - loss: 3.6659 - acc: 0.0754 - val_loss: 4.9894 - val_acc: 0.0204
Epoch 23/1000
62560/62560 [==============================] - 2s - loss: 3.6528 - acc: 0.0784 - val_loss: 5.1009 - val_acc: 0.0206
Epoch 24/1000
62560/62560 [==============================] - 2s - loss: 3.6439 - acc: 0.0800 - val_loss: 6.0815 - val_acc: 0.0212
Epoch 25/1000
62560/62560 [==============================] - 2s - loss: 3.6384 - acc: 0.0832 - val_loss: 5.4393 - val_acc: 0.0205
Epoch 26/1000
62560/62560 [==============================] - 2s - loss: 3.6113 - acc: 0.0883 - val_loss: 5.4142 - val_acc: 0.0205
Epoch 27/1000
62560/62560 [==============================] - 2s - loss: 3.5986 - acc: 0.0927 - val_loss: 5.3680 - val_acc: 0.0206
Epoch 28/1000
62560/62560 [==============================] - 2s - loss: 3.5859 - acc: 0.0945 - val_loss: 5.2954 - val_acc: 0.0206
Epoch 29/1000
62560/62560 [==============================] - 2s - loss: 3.5925 - acc: 0.0923 - val_loss: 5.4587 - val_acc: 0.0206
Epoch 30/1000
62560/62560 [==============================] - 2s - loss: 3.5649 - acc: 0.0975 - val_loss: 5.6845 - val_acc: 0.0205
Epoch 31/1000
62560/62560 [==============================] - 2s - loss: 3.5553 - acc: 0.0995 - val_loss: 6.7532 - val_acc: 0.0196
Epoch 32/1000
62560/62560 [==============================] - 2s - loss: 3.5953 - acc: 0.1059 - val_loss: 5.8451 - val_acc: 0.0206
Epoch 33/1000
62560/62560 [==============================] - 2s - loss: 3.5231 - acc: 0.1065 - val_loss: 5.9717 - val_acc: 0.0205
Epoch 34/1000
62560/62560 [==============================] - 2s - loss: 3.5117 - acc: 0.1091 - val_loss: 6.2294 - val_acc: 0.0205
Epoch 35/1000
62560/62560 [==============================] - 2s - loss: 3.5055 - acc: 0.1108 - val_loss: 6.0856 - val_acc: 0.0203
Epoch 36/1000
62560/62560 [==============================] - 2s - loss: 3.4875 - acc: 0.1130 - val_loss: 6.3182 - val_acc: 0.0207
Epoch 37/1000
62560/62560 [==============================] - 2s - loss: 3.4788 - acc: 0.1151 - val_loss: 6.2881 - val_acc: 0.0205
Epoch 38/1000
62560/62560 [==============================] - 2s - loss: 3.4838 - acc: 0.1141 - val_loss: 6.3116 - val_acc: 0.0205
Epoch 39/1000
62560/62560 [==============================] - 2s - loss: 3.4705 - acc: 0.1181 - val_loss: 6.3390 - val_acc: 0.0205
Epoch 40/1000
62560/62560 [==============================] - 2s - loss: 3.4545 - acc: 0.1207 - val_loss: 6.5663 - val_acc: 0.0206
Epoch 41/1000
62560/62560 [==============================] - 2s - loss: 3.4555 - acc: 0.1201 - val_loss: 6.4602 - val_acc: 0.0209
Epoch 42/1000
62560/62560 [==============================] - 2s - loss: 3.4315 - acc: 0.1246 - val_loss: 6.3524 - val_acc: 0.0206
Epoch 43/1000
62560/62560 [==============================] - 2s - loss: 3.4235 - acc: 0.1266 - val_loss: 6.6556 - val_acc: 0.0205
Epoch 44/1000
62560/62560 [==============================] - 2s - loss: 3.4294 - acc: 0.1279 - val_loss: 6.5271 - val_acc: 0.0207
Epoch 45/1000
62560/62560 [==============================] - 2s - loss: 3.4460 - acc: 0.1287 - val_loss: 6.8675 - val_acc: 0.0207
Epoch 46/1000
62560/62560 [==============================] - 2s - loss: 3.3956 - acc: 0.1305 - val_loss: 6.5386 - val_acc: 0.0208
Epoch 47/1000
62560/62560 [==============================] - 2s - loss: 3.3859 - acc: 0.1328 - val_loss: 6.8650 - val_acc: 0.0207
Epoch 48/1000
62560/62560 [==============================] - 2s - loss: 3.3656 - acc: 0.1361 - val_loss: 6.9698 - val_acc: 0.0207
Epoch 49/1000
62560/62560 [==============================] - 2s - loss: 3.3639 - acc: 0.1377 - val_loss: 7.2205 - val_acc: 0.0208
Epoch 50/1000
62560/62560 [==============================] - 2s - loss: 3.3570 - acc: 0.1390 - val_loss: 7.6807 - val_acc: 0.0212
Epoch 51/1000
62560/62560 [==============================] - 2s - loss: 3.3579 - acc: 0.1385 - val_loss: 7.1617 - val_acc: 0.0208
Epoch 52/1000
62560/62560 [==============================] - 2s - loss: 3.3636 - acc: 0.1322 - val_loss: 7.0801 - val_acc: 0.0208
Epoch 53/1000
62560/62560 [==============================] - 2s - loss: 3.3642 - acc: 0.1319 - val_loss: 7.0819 - val_acc: 0.0208
Epoch 54/1000
62560/62560 [==============================] - 2s - loss: 3.3558 - acc: 0.1325 - val_loss: 7.2601 - val_acc: 0.0208
Epoch 55/1000
62560/62560 [==============================] - 2s - loss: 3.3486 - acc: 0.1348 - val_loss: 7.0712 - val_acc: 0.0208
Epoch 56/1000
62560/62560 [==============================] - 2s - loss: 3.3403 - acc: 0.1334 - val_loss: 7.5916 - val_acc: 0.0207

1 Answer:

Answer 0 (score: 0):

First of all, did you normalize your data? (As pointed out by djk47463.)

Regarding your validation results: suppose your classifier did nothing more than guess uniformly at random among your 46 classes. It would then be right about 1/46 ≈ 0.022 of the time, which is roughly what you are getting. So it is fairly safe to assume your model is not learning anything.

If your dataset is properly scaled (e.g., channel-wise zero mean and unit standard deviation), then I would say your model is too small.
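For reference, a minimal normalization sketch (x_train and x_val are placeholder names for your image arrays; compute the statistics on the training set only):

import numpy as np

# Placeholder arrays of shape (num_samples, 28, 28, 3) with raw pixel values.
# Compute the per-channel mean and stddev over the training set only,
# then apply the same statistics to the validation set.
mean = x_train.mean(axis=(0, 1, 2), keepdims=True)
std = x_train.std(axis=(0, 1, 2), keepdims=True)
x_train = (x_train - mean) / (std + 1e-7)
x_val = (x_val - mean) / (std + 1e-7)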

Take a look at the reference mnist model for an example: https://github.com/fchollet/keras/blob/master/examples/mnist_cnn.py

model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
                 activation='relu',
                 input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))

The convolutional part has about 20k parameters, while the fully connected layers have about 1.2M parameters.

In comparison, your model has:

  • (3 * 3 * 3 + 1) * 4 = 112 parameters from the first layer (four 3x3x3 filters, plus one bias per filter)
  • (3 * 3 * 4 + 1) * 8 = 296 parameters from the second layer (eight 3x3x4 filters, plus one bias per filter)
  • about 172k parameters from the fully connected layers

So there are only 408 parameters in the convolutional layers, versus about 172k in the fully connected layers.
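As a sanity check, here is a small sketch for verifying these counts on the model object from the question (the 'conv'/'dense' substring matching relies on Keras's default layer names like conv2d_1 and dense_1):

# Keras can print a per-layer parameter breakdown:
model.summary()

# Or sum the convolutional vs. dense parameters directly:
conv_params = sum(layer.count_params() for layer in model.layers if 'conv' in layer.name)
dense_params = sum(layer.count_params() for layer in model.layers if 'dense' in layer.name)
print(conv_params, dense_params)  # expected: 408 and roughly 172k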

You probably don't need a fully connected layer that large, but I'm fairly sure the convolutional parameter count needs to go up by at least a factor of 10.

TL;DR: try adding more convolutional layers, each with more filters, for example:

Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 3))
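For instance, a minimal sketch adapting the reference mnist architecture above to your input shape and 46 classes (the filter counts of 32/64 and the dense width of 128 are just a starting point, not tuned values):

from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout

model = Sequential()
# Wider convolutional stack: 32 and 64 filters instead of 4 and 8.
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 3)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(46, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='Adadelta', metrics=['accuracy'])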