I have a training set of 32 images of chairs and 32 images of lamps. I run them through a simple CNN in Keras whose goal is to distinguish chairs from lamps. When I call model.fit(), the printed output shows my accuracy approaching 1.0. However, after training, model.evaluate() on the same training data reports an accuracy of ~0.5.
Here is my code:
# train_tensors, train_labels contain training data
model = keras.models.Sequential()
model.add(Conv2D(filters=5,
                 kernel_size=[4, 4],
                 strides=2,
                 padding='same',
                 input_shape=[225, 225, 3]))
model.add(LeakyReLU(0.2))
model.add(Conv2D(filters=10,
                 kernel_size=[4, 4],
                 strides=2,
                 padding='same'))
model.add(BatchNorm(axis=3))  # imported BatchNormalization as BatchNorm
model.add(LeakyReLU(0.2))
model.add(Flatten())
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])
model.fit(train_tensors, train_labels, batch_size=8, epochs=5, shuffle=True)
metrics = model.evaluate(train_tensors, train_labels)
print('')
print(np.ravel(model.predict(train_tensors)))
print('training data results: ')
for i in range(len(model.metrics_names)):
    print(str(model.metrics_names[i]) + ": " + str(metrics[i]))
The same train_tensors is used for both model.fit() and model.evaluate(). The printed output is:
Epoch 1/5
8/64 [==>...........................] - ETA: 0s - loss: 0.6927 - acc: 0.5000
16/64 [======>.......................] - ETA: 0s - loss: 1.1480 - acc: 0.5000
24/64 [==========>...................] - ETA: 0s - loss: 0.8194 - acc: 0.6667
40/64 [=================>............] - ETA: 0s - loss: 0.5205 - acc: 0.8000
56/64 [=========================>....] - ETA: 0s - loss: 0.4896 - acc: 0.8036
64/64 [==============================] - 0s - loss: 0.4504 - acc: 0.8125
Epoch 2/5
8/64 [==>...........................] - ETA: 0s - loss: 0.0082 - acc: 1.0000
24/64 [==========>...................] - ETA: 0s - loss: 0.0391 - acc: 1.0000
40/64 [=================>............] - ETA: 0s - loss: 0.0444 - acc: 0.9750
56/64 [=========================>....] - ETA: 0s - loss: 0.0392 - acc: 0.9821
64/64 [==============================] - 0s - loss: 0.0382 - acc: 0.9844
Epoch 3/5
8/64 [==>...........................] - ETA: 0s - loss: 4.8843e-04 - acc: 1.0000
24/64 [==========>...................] - ETA: 0s - loss: 7.5668e-04 - acc: 1.0000
40/64 [=================>............] - ETA: 0s - loss: 6.1193e-04 - acc: 1.0000
56/64 [=========================>....] - ETA: 0s - loss: 0.0096 - acc: 1.0000
64/64 [==============================] - 0s - loss: 0.0128 - acc: 1.0000
Epoch 4/5
8/64 [==>...........................] - ETA: 0s - loss: 9.2490e-04 - acc: 1.0000
24/64 [==========>...................] - ETA: 0s - loss: 9.6854e-04 - acc: 1.0000
40/64 [=================>............] - ETA: 0s - loss: 9.6813e-04 - acc: 1.0000
56/64 [=========================>....] - ETA: 0s - loss: 7.5456e-04 - acc: 1.0000
64/64 [==============================] - 0s - loss: 8.3200e-04 - acc: 1.0000
Epoch 5/5
8/64 [==>...........................] - ETA: 0s - loss: 4.2928e-04 - acc: 1.0000
24/64 [==========>...................] - ETA: 0s - loss: 0.0044 - acc: 1.0000
40/64 [=================>............] - ETA: 0s - loss: 0.0027 - acc: 1.0000
56/64 [=========================>....] - ETA: 0s - loss: 0.0026 - acc: 1.0000
64/64 [==============================] - 0s - loss: 0.0024 - acc: 1.0000
32/64 [==============>...............] - ETA: 0s
64/64 [==============================] - 0s
[ 2.20039312e-21 3.70743738e-15 7.76885543e-08 3.38629164e-20
1.26636347e-14 8.46270983e-23 4.83105518e-24 1.63172146e-07
3.59334761e-28 7.74249325e-20 6.30969798e-28 2.79597981e-12
1.17927814e-21 3.84340554e-01 3.83124183e-23 4.88756598e-07
8.28199488e-27 3.89127730e-16 7.77586222e-32 2.96250031e-21
1.51558620e-22 3.26927439e-12 1.96537564e-20 2.68915438e-16
2.90332289e-17 1.78180949e-03 6.45235020e-23 2.82894642e-25
9.87989724e-01 5.52072190e-02 6.61221920e-31 6.48611497e-29
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 2.91474397e-38
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
1.56358186e-38 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]
training data results:
loss: 7.17852215999
acc: 0.515625
There is a huge discrepancy between the accuracy reported during model.fit() (~1.0) and the accuracy from model.evaluate() (0.515625). Why does this happen?
Answer (score: 2):
The problem is the BatchNormalization layer. As you noted in the comments, removing that layer makes the model work. The issue is that batchnorm behaves differently at training time and at test time. At training time, it normalizes using statistics computed per batch. At test time, it normalizes using running averages of the mean/stdev accumulated over the course of training, as an estimate of the statistics of the whole training set. Because your training run is so short, those running statistics likely have not had time to converge to accurate values.
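The effect can be sketched with a quick simulation (made-up numbers, not the poster's data). Batchnorm updates its running mean as moving = momentum * moving + (1 - momentum) * batch_mean, with Keras defaulting momentum to 0.99. After only 40 updates (5 epochs × 8 batches, as in the log above), the running mean is still far from the batch statistics used during training:

```python
momentum = 0.99   # Keras BatchNormalization default
batch_mean = 5.0  # assume, for simplicity, every batch has mean 5.0

moving_mean = 0.0  # the moving mean starts at zero
for _ in range(40):  # 5 epochs x 8 batches, as in the training log
    moving_mean = momentum * moving_mean + (1 - momentum) * batch_mean

# moving_mean is about 1.66, nowhere near the 5.0 used at train time
print(moving_mean)
```

So at evaluation time the layer normalizes with statistics very different from those the downstream weights were trained against, and accuracy collapses toward chance.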
You can (a) remove the batchnorm layer, which, as I mentioned in the comments, seems to work; or (b) make the moving mean/stdev update faster by lowering the momentum parameter of the BatchNormalization layer. Try values in the range [0.5, 0.95].
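As a rough illustration of option (b), again with made-up numbers: the running mean follows moving = momentum * moving + (1 - momentum) * batch_mean, so over the 40 updates in the run above, a lower momentum converges far faster than the default 0.99:

```python
def run_updates(momentum, batch_mean=5.0, n_updates=40):
    """Simulate batchnorm's running-mean update over n_updates batches."""
    moving = 0.0
    for _ in range(n_updates):
        moving = momentum * moving + (1 - momentum) * batch_mean
    return moving

print(run_updates(0.99))  # default: still far from 5.0
print(run_updates(0.7))   # lower momentum: essentially at 5.0
```

In Keras this would look like BatchNormalization(axis=3, momentum=0.7) in place of the current layer; the exact value is a tuning choice, and on a 64-sample set the per-batch statistics are noisy, so some experimentation is expected.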