Tensorflow的batch_norm中的模型变量

时间:2017-01-15 22:21:37

标签: tensorflow

在线文档说,moving_average和moving_variance都是model_variables,而tf.model_variables()则返回local_variables类型的张量。这是否意味着在保存状态时不会保存model_variables?

我试图将批量标准化应用于几个3D卷积和完全连接的层。我使用batch_norm训练我的网络并保存了一个检查点文件,但当我去恢复我保存的状态时,它说找不到moving_mean。确切的错误是,当TF将恢复的值分配给moving_mean时,lhs张量的形状[]无法与rhs的形状协调,[20]。

当我不在我的图层周围添加batch_norm时,图表会恢复正常。 我计划在训练结束时添加一个全局变量来保存我的moving_mean和moving_variance值。这是TF对我使用batch_norm的方式吗?

谢谢!

1 个答案:

答案 0 :(得分:1)

变量moving_mean和moving_variance不在我保存的说明中,因为我已将updates_collections设置为默认值。由于我在运行图层时从未包含控件依赖项,因此这些变量从未更新过。

要包含的代码是:

from tensorflow.python import control_flow_ops

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
if update_ops:
    updates = tf.tuple(update_ops)
    total_loss = control_flow_ops.with_dependencies(updates, total_loss)

或设置

updates_collection=None 

进行就地更新。

有关详细信息,请参阅the API descriptioncurrent github discussion