我训练了一个带有标准化层的模型。代码如下:
在培训阶段:
model=Sequential()
model.add()
...
k.set_learning_phase(1)
ModelCheckpoint(weights_file)
model.fit()
在推理时间:
k.set_learning_phase(0)
model.load_weights(weights_file)
model.predict_classes()
...
Keras的版本:2.0.8。这是正确的,还是需要一些特殊的代码来训练后像在Caffe中使用SegNet一样计算BN?
答案 0 :(得分:2)
不,在使用BatchNormalization或Dropout图层时,您不需要做任何特殊操作。 Keras已经跟踪了学习/测试阶段,因此在使用predict
或predict_classes
时,它会做正确的事情。
您甚至不需要手动设置学习阶段,Keras已经完成了。