Keras版本:2.3
Tensorflow后端:1.14
我正在使用ResNet的Keras实现,在此实现中,我已将BatchNormalization层的训练标志设置为False。
我已经了解到,当可训练标记设置为False且该层以100%推断模式运行时,从Keras 2.1.3版及以后的版本中将不学习批处理统计信息。
来源: https://github.com/keras-team/keras/releases/tag/2.1.3
我进行了ResNet模型的训练,两次训练之间有一种差异。在其中之一中,我使用以下命令设置了learning_phase
K.set_learning_phase(0)
(谓词或推理模式)和另一个K.set_learning_phase(1)
(训练模式)。
我从两者中得到不同的损失,幅度也完全不同。例如:使用K.set_learning_phase(1)
(训练模式)时,损耗为:5.6274
而另一种推断是:3.2310。
我在两个测试中都使用了相同的数据。没有遗漏,没有正规化可能影响模型中的损失。据我认为,没有其他设置可以导致损失值出现这种差异。
为了确保可训练的Flag正常工作,我还比较了这两种模型,并且更改设置后参数的数量也会发生变化。
Number of Parameters with setting Trainable = True in BatchNormalization
Number of Parameters with setting Trainable=False in BatchNormalization
我希望无论set_learning_phase值如何,当可训练标志设置为False时,模型都应做出类似响应。至少从更改日志中可以了解到这一点,当标志设置为False时,层以100%推断模式运行。
任何了解这种行为发生原因的见解都是很棒的!