鉴于我有一个training
的张量变量,在训练时将其设置为true
,在测试时将其设置为false
。
tf.layers.batch_normalization(input_tensor, training=training) # training = tensor
tf.metrics.mean(loss, updates_collections=tf.GraphKeys.UPDATE_OPS)
除了批量标准化外,我还将自己的更新操作添加到UPDATE_OPS
。
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
def train_run(epoch):
sess.run([dataset.train_init_op, init_local]) # Sets training to true
for i in range(dataset.train_iterations_per_epoch):
sess.run([train_op, update_ops])
def test_run(epoch):
sess.run([dataset.test_init_op, init_local]) # Sets training to false
for i in range(dataset.test_iterations_per_epoch):
sess.run(update_ops)
UPDATE_OPS
中的test_run
仍包含所有批处理规范化操作,但training
为假。
是否仍通过运行批处理规范更新操作来进行更新?
还是我将可训练的var与指标ops混淆了,并且在测试期间是否对其进行更新并不重要?