如何预测在Keras中使用标准化或辍学层进行训练时?

时间:2017-09-04 07:27:20

标签: python machine-learning keras caffe conv-neural-network

我训练了一个带有标准化层的模型。代码如下:

在培训阶段:

model=Sequential()
model.add()

...

k.set_learning_phase(1)
ModelCheckpoint(weights_file)
model.fit()

在推理时间:

k.set_learning_phase(0)
model.load_weights(weights_file)
model.predict_classes()

...

Keras的版本:2.0.8。这是正确的,还是需要一些特殊的代码来训练后像在Caffe中使用SegNet一样计算BN?

1 个答案:

答案 0 :(得分:2)

不,在使用BatchNormalization或Dropout图层时,您不需要做任何特殊操作。 Keras已经跟踪了学习/测试阶段,因此在使用predictpredict_classes时,它会做正确的事情。

您甚至不需要手动设置学习阶段,Keras已经完成了。