Keras model.fit() and model.evaluate() are inconsistent

Date: 2017-03-27 19:53:54

Tags: python keras

I have a training set of 32 chairs and 32 lamps. I run them through a simple CNN in Keras whose goal is to distinguish chairs from lamps. When I call model.fit(), the printed accuracy approaches 1.0. However, after training, model.evaluate() on the same training data reports an accuracy of ~0.5.

Here is my code:

# train_tensors, train_labels contain the training data
import numpy as np
import keras
from keras.layers import Conv2D, LeakyReLU, Flatten, Dense
from keras.layers import BatchNormalization as BatchNorm

model = keras.models.Sequential()
model.add(Conv2D(filters=5,
                 kernel_size=[4, 4],
                 strides=2,
                 padding='same',
                 input_shape=[225, 225, 3]))
model.add(LeakyReLU(0.2))

model.add(Conv2D(filters=10,
                 kernel_size=[4, 4],
                 strides=2,
                 padding='same'))
model.add(BatchNorm(axis=3))  # imported BatchNormalization as BatchNorm
model.add(LeakyReLU(0.2))

model.add(Flatten())
model.add(Dense(1, activation='sigmoid'))

model.compile(loss='binary_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

model.fit(train_tensors, train_labels, batch_size=8, epochs=5, shuffle=True)

metrics = model.evaluate(train_tensors, train_labels)
print('')
print(np.ravel(model.predict(train_tensors)))
print('training data results: ')
for i in range(len(model.metrics_names)):
    print(str(model.metrics_names[i]) + ": " + str(metrics[i]))

The same train_tensors are used for both model.fit() and model.evaluate(). The printed output is:

Epoch 1/5
 8/64 [==>...........................] - ETA: 0s - loss: 0.6927 - acc: 0.5000
16/64 [======>.......................] - ETA: 0s - loss: 1.1480 - acc: 0.5000
24/64 [==========>...................] - ETA: 0s - loss: 0.8194 - acc: 0.6667
40/64 [=================>............] - ETA: 0s - loss: 0.5205 - acc: 0.8000
56/64 [=========================>....] - ETA: 0s - loss: 0.4896 - acc: 0.8036
64/64 [==============================] - 0s - loss: 0.4504 - acc: 0.8125     
Epoch 2/5
 8/64 [==>...........................] - ETA: 0s - loss: 0.0082 - acc: 1.0000
24/64 [==========>...................] - ETA: 0s - loss: 0.0391 - acc: 1.0000
40/64 [=================>............] - ETA: 0s - loss: 0.0444 - acc: 0.9750
56/64 [=========================>....] - ETA: 0s - loss: 0.0392 - acc: 0.9821
64/64 [==============================] - 0s - loss: 0.0382 - acc: 0.9844     
Epoch 3/5
 8/64 [==>...........................] - ETA: 0s - loss: 4.8843e-04 - acc: 1.0000
24/64 [==========>...................] - ETA: 0s - loss: 7.5668e-04 - acc: 1.0000
40/64 [=================>............] - ETA: 0s - loss: 6.1193e-04 - acc: 1.0000
56/64 [=========================>....] - ETA: 0s - loss: 0.0096 - acc: 1.0000    
64/64 [==============================] - 0s - loss: 0.0128 - acc: 1.0000     
Epoch 4/5
 8/64 [==>...........................] - ETA: 0s - loss: 9.2490e-04 - acc: 1.0000
24/64 [==========>...................] - ETA: 0s - loss: 9.6854e-04 - acc: 1.0000
40/64 [=================>............] - ETA: 0s - loss: 9.6813e-04 - acc: 1.0000
56/64 [=========================>....] - ETA: 0s - loss: 7.5456e-04 - acc: 1.0000
64/64 [==============================] - 0s - loss: 8.3200e-04 - acc: 1.0000     
Epoch 5/5
 8/64 [==>...........................] - ETA: 0s - loss: 4.2928e-04 - acc: 1.0000
24/64 [==========>...................] - ETA: 0s - loss: 0.0044 - acc: 1.0000    
40/64 [=================>............] - ETA: 0s - loss: 0.0027 - acc: 1.0000
56/64 [=========================>....] - ETA: 0s - loss: 0.0026 - acc: 1.0000
64/64 [==============================] - 0s - loss: 0.0024 - acc: 1.0000     
32/64 [==============>...............] - ETA: 0s
64/64 [==============================] - 0s     

[  2.20039312e-21   3.70743738e-15   7.76885543e-08   3.38629164e-20
   1.26636347e-14   8.46270983e-23   4.83105518e-24   1.63172146e-07
   3.59334761e-28   7.74249325e-20   6.30969798e-28   2.79597981e-12
   1.17927814e-21   3.84340554e-01   3.83124183e-23   4.88756598e-07
   8.28199488e-27   3.89127730e-16   7.77586222e-32   2.96250031e-21
   1.51558620e-22   3.26927439e-12   1.96537564e-20   2.68915438e-16
   2.90332289e-17   1.78180949e-03   6.45235020e-23   2.82894642e-25
   9.87989724e-01   5.52072190e-02   6.61221920e-31   6.48611497e-29
   0.00000000e+00   0.00000000e+00   0.00000000e+00   0.00000000e+00
   0.00000000e+00   0.00000000e+00   0.00000000e+00   0.00000000e+00
   0.00000000e+00   0.00000000e+00   0.00000000e+00   2.91474397e-38
   0.00000000e+00   0.00000000e+00   0.00000000e+00   0.00000000e+00
   1.56358186e-38   0.00000000e+00   0.00000000e+00   0.00000000e+00
   0.00000000e+00   0.00000000e+00   0.00000000e+00   0.00000000e+00
   0.00000000e+00   0.00000000e+00   0.00000000e+00   0.00000000e+00
   0.00000000e+00   0.00000000e+00   0.00000000e+00   0.00000000e+00]
training data results: 
loss: 7.17852215999
acc: 0.515625

There is a huge difference between the accuracy during model.fit() (1.0) and the accuracy from model.evaluate() (0.515625). Why does this happen?

1 Answer:

Answer 0 (score: 2)

The problem is the BatchNormalization layer. As you noted in the comments, removing that layer makes the model work. The issue is that batch norm behaves differently at training time versus test time. At training time, it normalizes with statistics computed per batch. At test time, it normalizes with running averages of the mean/stdev accumulated over the whole training run. Since your training run is very short, those running statistics most likely haven't converged to accurate values yet.
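The train/test mismatch can be sketched with plain NumPy (a simplified batch norm that ignores the learned scale/shift parameters; the `running_mean`/`running_var` values below are illustrative stand-ins for poorly converged running statistics, not numbers from the question):

```python
import numpy as np

rng = np.random.default_rng(0)
batch = rng.normal(loc=5.0, scale=2.0, size=64)  # one batch of activations

# Training mode: normalize with the current batch's own statistics.
train_out = (batch - batch.mean()) / np.sqrt(batch.var() + 1e-3)

# Inference mode: normalize with the accumulated running statistics.
# After only a few short epochs these can still be far from the data's
# true mean/variance (they start near their 0/1 initialization).
running_mean, running_var = 0.5, 1.2  # illustrative, poorly converged
test_out = (batch - running_mean) / np.sqrt(running_var + 1e-3)

print(train_out.mean(), train_out.std())  # ~0 and ~1, as intended
print(test_out.mean(), test_out.std())    # badly shifted and rescaled
```

The layers after batch norm were trained on inputs that look like `train_out`, so feeding them the shifted `test_out` distribution at evaluation time can wreck the predictions, which matches the near-zero probabilities printed by model.predict() above.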

You can either (a) remove the batchnorm layer, which as I mentioned in the comments seems to work, or (b) speed up the running mean/stdev updates by lowering the momentum parameter of the BatchNormalization layer. Try a momentum in the range [0.5, 0.95].
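Why a lower momentum helps can be sketched in plain Python, assuming the update rule Keras documents for its moving averages, `new = momentum * old + (1 - momentum) * batch_stat`. The question's run sees 40 batches in total (64 samples / batch size 8 = 8 batches per epoch, times 5 epochs); the true activation mean of 5.0 is an illustrative value:

```python
TRUE_MEAN = 5.0  # illustrative true mean of the activations
STEPS = 40       # 5 epochs * 8 batches, as in the question


def final_moving_mean(momentum, steps=STEPS):
    """Moving mean after `steps` batches, pretending each batch is exact."""
    moving_mean = 0.0  # Keras initializes the moving mean at zero
    for _ in range(steps):
        moving_mean = momentum * moving_mean + (1 - momentum) * TRUE_MEAN
    return moving_mean


for m in (0.99, 0.9, 0.7):
    # High momentum (e.g. 0.99) leaves the moving mean far from 5.0
    # after only 40 batches; lower momentum converges much faster.
    print(f"momentum={m}: moving mean after {STEPS} batches "
          f"= {final_moving_mean(m):.3f}")
```

With momentum near its usual default of 0.99, 40 batches are nowhere near enough for the moving mean to reach the true statistic, which is exactly the short-training-run failure mode described above; dropping momentum into the suggested [0.5, 0.95] range lets the running statistics catch up within the same number of batches.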