更新batch_normalization意味着&使用Estimator API进行差异

时间:2018-03-10 02:41:46

标签: tensorflow machine-learning batch-normalization tensorflow-estimator

文档并非100%明确:

  

注意:训练时,需要更新moving_mean和moving_variance。默认情况下,更新操作位于tf.GraphKeys.UPDATE_OPS中,因此需要将它们作为依赖项添加到train_op。例如:

(见https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization

这是否意味着保存moving_meanmoving_variance所需的全部内容如下所示?

def model_fn(features, labels, mode, params):
   training = mode == tf.estimator.ModeKeys.TRAIN
   extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

   x = tf.reshape(features, [-1, 64, 64, 3])
   x = tf.layers.batch_normalization(x, training=training)

   # ...

  with tf.control_dependencies(extra_update_ops):
     train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())

换句话说,只需使用

with tf.control_dependencies(extra_update_ops):

注意保存moving_meanmoving_variance

2 个答案:

答案 0 :(得分:1)

是的,添加这些控件依赖项将保存均值和方差。

答案 1 :(得分:1)

事实证明,这些值可以自动保存。边缘情况是,如果在将批量标准化操作添加到图形之前获得更新操作集合,则更新集合将为空。以前没有记录,但现在。

使用batch_norm时的警告是在您调用tf.get_collection(tf.GraphKeys.UPDATE_OPS)后致电tf.layers.batch_normalization