如何在keras BatchNormalization中更新移动均值和移动方差?
我在tensorflow文档中找到了这个,但我不知道在哪里放train_op
或如何使用keras模型:
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize( loss )
我找到的帖子没有说明如何处理train_op以及是否可以在model.compile
中使用它。
答案 0 :(得分:0)
如果您只需要使用一些新值更新现有模型的权重,则可以执行以下操作:
w = model.get_layer('batchnorm_layer_name').get_weights()
# Order: [gamma, beta, mean, std]
for j in range(len(w[0])):
gamma = w[0][j]
beta = w[1][j]
run_mean = w[2][j]
run_std = w[3][j]
w[2][j] = new_run_mean_value1
w[3][j] = new_run_std_value2
model.get_layer('batchnorm_layer_name').set_weights(w)
答案 1 :(得分:0)
如果使用BatchNormalization图层,则无需手动更新移动均值和方差。 Keras负责在培训期间更新这些参数,并在测试期间保持固定(使用model.predict
和model.evaluate
功能,与model.fit_generator
和朋友相同。
Keras还会跟踪学习阶段,以便在培训和验证/测试期间运行不同的代码路径。
答案 2 :(得分:0)
对该问题有两种解释:第一种是假设目标是使用高级培训api,而Matias Valdenegro回答了这个问题。
第二点(如评论中所述)是是否可以通过此处keras a simplified tensorflow interface和“收集可训练的权重和状态更新”部分讨论的标准张量流优化器使用批处理规范化。如前所述,更新操作可在layer.updates中访问,而不能在$ prebuilts/sdk/tools/jack-admin kill-server
$ prebuilts/sdk/tools/jack-admin start-server
中访问,实际上,如果您在tensorflow中具有keras模型,则可以使用标准tensorflow优化器和批处理规范化进行优化
tf.GraphKeys.UPDATE_OPS
,然后使用张量流会话获取train_op。为了区分批次归一化层的训练和评估模式,您需要输入 keras引擎的学习阶段状态(请参见上面给出的同一tutorial page上的“训练和测试期间的不同行为”)。例如,这样可以工作
update_ops = model.updates
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize( loss )
我在tensorflow 1.12中尝试了此方法,它适用于包含批处理规范化的模型。鉴于我现有的tensorflow代码,并且鉴于即将接近tensorflow 2.0版,我很想亲自使用此方法,但是鉴于tensorflow文档中未提及该方法,因此我不确定长期是否会支持该方法,我最终决定不使用它,并投入更多的钱来更改代码以使用高级api。