Tensorflow`tf.layers.batch_normalization`不会将更新操作添加到`tf.GraphKeys.UPDATE_OPS`

时间:2018-02-19 21:38:20

标签: python tensorflow

以下代码(复制/粘贴可运行)说明使用tf.layers.batch_normalization

import tensorflow as tf
bn = tf.layers.batch_normalization(tf.constant([0.0]))
print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))

> []     # UPDATE_OPS collection is empty

使用TF 1.5,文档(如下所述)明确指出 UPDATE_OPS不应该为空在这种情况下(https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization):

  

注意:训练时,moving_mean和moving_variance需要   更新。默认情况下,更新操作位于   tf.GraphKeys.UPDATE_OPS,因此需要将它们添加为依赖项   train_op。例如:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)

1 个答案:

答案 0 :(得分:5)

只需将您的代码更改为处于培训模式(将training标记设置为True),如quote中所述:

  

注意:培训时,需要更新moving_mean和moving_variance。默认情况下,更新操作位于tf.GraphKeys.UPDATE_OPS中,因此需要将它们作为依赖项添加到train_op。

 import tensorflow as tf
 bn = tf.layers.batch_normalization(tf.constant([0.0]), training=True)
 print(tf.get_collection(tf.GraphKeys.UPDATE_OPS))

将输出:

[< tf.Tensor 'batch_normalization/AssignMovingAvg:0' shape=(1,) dtype=float32_ref>, 
 < tf.Tensor 'batch_normalization/AssignMovingAvg_1:0' shape=(1,) dtype=float32_ref>]

和Gamma和Beta最终出现在TRAINABLE_VARIABLES集合中:

print(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))

[<tf.Variable 'batch_normalization/gamma:0' shape=(1,) dtype=float32_ref>, 
 <tf.Variable 'batch_normalization/beta:0' shape=(1,) dtype=float32_ref>]