使用Keras进行跨语言学习时,如何更新moving_mean / moving_variance

时间:2019-01-07 08:19:18

标签: keras

出于某种原因,我想使用Kearas预训练模型和张量流进行训练。 因为包含了BN,所以我们需要手动更新moving_mean / moving_variance。解决方案如下:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
         train_step = tf.train.AdamOptimizer(learning_rate=0.0001).minimize

但是,当使用keras模型时,tf.get_collection(tf.GraphKeys.UPDATE_OPS)返回一个空列表。所以我就这样改变了:

update_ops = base_model.updates
...
with tf.control_dependencies(update_ops):
             train_step = tf.train.AdamOptimizer(learning_rate=0.0001).minimize

然后,总是有一个错误:

  

您必须使用dtype输入占位符张量'input_1'的值   浮动和形状[?,224,224,3],似乎没有数据馈入   模型。当我用“ tf.control_dependencies(update_ops)删除”时,它   可行,结果当然是错误的,因为   moving_mean / moving_variance未更新

0 个答案:

没有答案