Tensorflow中自定义估算器中的批量标准化

时间:2017-07-25 09:42:28

标签: tensorflow tensorboard

我指的是tf.layers.batch_normilization处的注释:

  

注意:训练时,需要更新moving_mean和moving_variance。默认情况下,更新操作位于tf.GraphKeys.UPDATE_OPS中,因此需要将它们作为依赖项添加到train_op。例如:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(loss)

如何在Custom Estimator中实现这一点?例如,在Tensorflow的网站上查看此示例:The complete abalone model_fn

2 个答案:

答案 0 :(得分:2)

关于以下问题,在最底部有一个示例 https://github.com/tensorflow/tensorflow/issues/16455

if mode == tf.estimator.ModeKeys.TRAIN:
    lr = 0.001
    optimizer = tf.train.RMSPropOptimizer(learning_rate=lr, decay=0.9)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      train_op=train_op)

答案 1 :(得分:0)

我猜你可以传递train_op,你可以参考EstimatorSpec的train_op参数。