(Keras版本2.3)在BatchNormalization中设置Trainable = False不起作用

时间:2019-12-06 09:05:10

标签: python tensorflow keras

Keras版本:2.3

Tensorflow后端:1.14

我正在使用ResNet的Keras实现,在此实现中,我已将BatchNormalization层的训练标志设置为False。

我已经了解到,当可训练标记设置为False且该层以100%推断模式运行时,从Keras 2.1.3版及以后的版本中将不学习批处理统计信息。

来源: https://github.com/keras-team/keras/releases/tag/2.1.3

我进行了ResNet模型的训练,两次训练之间有一种差异。在其中之一中,我使用以下命令设置了learning_phase K.set_learning_phase(0)(谓词或推理模式)和另一个K.set_learning_phase(1)(训练模式)。

我从两者中得到不同的损失,幅度也完全不同。例如:使用K.set_learning_phase(1)(训练模式)时,损耗为:5.6274 而另一种推断是:3.2310。

我在两个测试中都使用了相同的数据。没有遗漏,没有正规化可能影响模型中的损失。据我认为,没有其他设置可以导致损失值出现这种差异。

为了确保可训练的Flag正常工作,我还比较了这两种模型,并且更改设置后参数的数量也会发生变化。

Number of Parameters with setting Trainable = True in BatchNormalization

Number of Parameters with setting Trainable=False in BatchNormalization

我希望无论set_learning_phase值如何,当可训练标志设置为False时,模型都应做出类似响应。至少从更改日志中可以了解到这一点,当标志设置为False时,层以100%推断模式运行。

任何了解这种行为发生原因的见解都是很棒的!

0 个答案:

没有答案