在测试期间运行时批处理规范化UPDATE_OPS的行为

时间:2019-02-08 21:57:33

标签: tensorflow batch-normalization

鉴于我有一个training的张量变量,在训练时将其设置为true,在测试时将其设置为false

tf.layers.batch_normalization(input_tensor, training=training)  # training = tensor
tf.metrics.mean(loss, updates_collections=tf.GraphKeys.UPDATE_OPS) 

除了批量标准化外,我还将自己的更新操作添加到UPDATE_OPS

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

def train_run(epoch):
    sess.run([dataset.train_init_op, init_local])  # Sets training to true
    for i in range(dataset.train_iterations_per_epoch):
        sess.run([train_op, update_ops])

def test_run(epoch):
    sess.run([dataset.test_init_op, init_local])  # Sets training to false
    for i in range(dataset.test_iterations_per_epoch):
        sess.run(update_ops)

UPDATE_OPS中的test_run仍包含所有批处理规范化操作,但training为假。

是否仍通过运行批处理规范更新操作来进行更新?

还是我将可训练的var与指标ops混淆了,并且在测试期间是否对其进行更新并不重要?

0 个答案:

没有答案