带有tf.placeholder的多个输入的Tensorflow图用于验证

时间:2018-10-01 15:36:50

标签: python tensorflow tensorflow-datasets

我在模型上使用tf.data API。现在,我将tf.data迭代器的输出定义为网络的输入。摆脱了feed_dict方法之后,我的性能大大提高了。

现在,我想实现一个至少在每次训练后运行一次的验证集。是否可以为tf.data实现验证运行,还是必须设置一个占位符并手动切换tf.datasets并再次使用feed_dicts?推荐的验证测试方法是什么?

1 个答案:

答案 0 :(得分:1)

破解方式-节点替换

最简单的方法(虽然绝对不是最漂亮的)只是将tf.data API创建的节点用作feed_dict的输入-这是因为在Tensorflow中,您可以替换任何值将其值直接输入feed_dict。

所以这就像

batch_input = tf_train_data_foo()
validation_input = tf_validation_data_foo()

model = build_model(batch_input)
optimization_step = some_optimization_foo(model)

# Regular train
session.run(optimization_step)

# Validation run
validation_data = session.run(validation_input)
session.run(model, {batch_input: validation_data})

更好的方法-变量重用

如果所有构造都使用tf.get_variable而不是创建新变量,并且所有作用域都设置为可以获取现有变量,则只需调用build_model函数两次-一次使用训练数据(来自tf.data)和一次验证数据。您可以在this answer

上查看有关变量重用的更多详细信息