在线文档说,moving_average和moving_variance都是model_variables,而tf.model_variables()则返回local_variables类型的张量。这是否意味着在保存状态时不会保存model_variables?
我试图将批量标准化应用于几个3D卷积和完全连接的层。我使用batch_norm训练我的网络并保存了一个检查点文件,但当我去恢复我保存的状态时,它说找不到moving_mean。确切的错误是,当TF将恢复的值分配给moving_mean时,lhs张量的形状[]无法与rhs的形状协调,[20]。
当我不在我的图层周围添加batch_norm时,图表会恢复正常。 我计划在训练结束时添加一个全局变量来保存我的moving_mean和moving_variance值。这是TF对我使用batch_norm的方式吗?
谢谢!
答案 0 :(得分:1)
变量moving_mean和moving_variance不在我保存的说明中,因为我已将updates_collections设置为默认值。由于我在运行图层时从未包含控件依赖项,因此这些变量从未更新过。
要包含的代码是:
from tensorflow.python import control_flow_ops
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
if update_ops:
updates = tf.tuple(update_ops)
total_loss = control_flow_ops.with_dependencies(updates, total_loss)
或设置
updates_collection=None
进行就地更新。
有关详细信息,请参阅the API description和current github discussion。