Keras中的BatchNormalization

时间:2018-05-03 22:16:26

标签: tensorflow keras batch-normalization

如何在keras BatchNormalization中更新移动均值和移动方差?

我在tensorflow文档中找到了这个,但我不知道在哪里放train_op或如何使用keras模型:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize( loss )

我找到的帖子没有说明如何处理train_op以及是否可以在model.compile中使用它。

3 个答案:

答案 0 :(得分:0)

如果您只需要使用一些新值更新现有模型的权重,则可以执行以下操作:

w = model.get_layer('batchnorm_layer_name').get_weights()
# Order: [gamma, beta, mean, std]
for j in range(len(w[0])):
    gamma = w[0][j]
    beta = w[1][j]
    run_mean = w[2][j]
    run_std = w[3][j]
    w[2][j] = new_run_mean_value1
    w[3][j] = new_run_std_value2

model.get_layer('batchnorm_layer_name').set_weights(w)

答案 1 :(得分:0)

如果使用BatchNormalization图层,则无需手动更新移动均值和方差。 Keras负责在培训期间更新这些参数,并在测试期间保持固定(使用model.predictmodel.evaluate功能,与model.fit_generator和朋友相同。

Keras还会跟踪学习阶段,以便在培训和验证/测试期间运行不同的代码路径。

答案 2 :(得分:0)

对该问题有两种解释:第一种是假设目标是使用高级培训api,而Matias Valdenegro回答了这个问题。

第二点(如评论中所述)是是否可以通过此处keras a simplified tensorflow interface和“收集可训练的权重和状态更新”部分讨论的标准张量流优化器使用批处理规范化。如前所述,更新操作可在layer.updates中访问,而不能在$ prebuilts/sdk/tools/jack-admin kill-server $ prebuilts/sdk/tools/jack-admin start-server 中访问,实际上,如果您在tensorflow中具有keras模型,则可以使用标准tensorflow优化器和批处理规范化进行优化

tf.GraphKeys.UPDATE_OPS

,然后使用张量流会话获取train_op。为了区分批次归一化层的训练和评估模式,您需要输入 keras引擎的学习阶段状态(请参见上面给出的同一tutorial page上的“训练和测试期间的不同行为”)。例如,这样可以工作

update_ops  = model.updates
with tf.control_dependencies(update_ops):
     train_op = optimizer.minimize( loss )

我在tensorflow 1.12中尝试了此方法,它适用于包含批处理规范化的模型。鉴于我现有的tensorflow代码,并且鉴于即将接近tensorflow 2.0版,我很想亲自使用此方法,但是鉴于tensorflow文档中未提及该方法,因此我不确定长期是否会支持该方法,我最终决定不使用它,并投入更多的钱来更改代码以使用高级api。