文档并非100%明确:
注意:训练时,需要更新moving_mean和moving_variance。默认情况下,更新操作位于tf.GraphKeys.UPDATE_OPS中,因此需要将它们作为依赖项添加到train_op。例如:
(见https://www.tensorflow.org/api_docs/python/tf/layers/batch_normalization)
这是否意味着保存moving_mean
和moving_variance
所需的全部内容如下所示?
def model_fn(features, labels, mode, params):
training = mode == tf.estimator.ModeKeys.TRAIN
extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
x = tf.reshape(features, [-1, 64, 64, 3])
x = tf.layers.batch_normalization(x, training=training)
# ...
with tf.control_dependencies(extra_update_ops):
train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())
换句话说,只需使用
with tf.control_dependencies(extra_update_ops):
注意保存moving_mean
和moving_variance
?
答案 0 :(得分:1)
是的,添加这些控件依赖项将保存均值和方差。
答案 1 :(得分:1)
事实证明,这些值可以自动保存。边缘情况是,如果在将批量标准化操作添加到图形之前获得更新操作集合,则更新集合将为空。以前没有记录,但现在。
使用batch_norm时的警告是在您调用tf.get_collection(tf.GraphKeys.UPDATE_OPS)
后致电tf.layers.batch_normalization
。