我在模型上使用tf.data
API。现在,我将tf.data
迭代器的输出定义为网络的输入。摆脱了feed_dict
方法之后,我的性能大大提高了。
现在,我想实现一个至少在每次训练后运行一次的验证集。是否可以为tf.data
实现验证运行,还是必须设置一个占位符并手动切换tf.datasets并再次使用feed_dicts
?推荐的验证测试方法是什么?
答案 0 :(得分:1)
最简单的方法(虽然绝对不是最漂亮的)只是将tf.data
API创建的节点用作feed_dict的输入-这是因为在Tensorflow中,您可以替换任何值将其值直接输入feed_dict。
所以这就像
batch_input = tf_train_data_foo()
validation_input = tf_validation_data_foo()
model = build_model(batch_input)
optimization_step = some_optimization_foo(model)
# Regular train
session.run(optimization_step)
# Validation run
validation_data = session.run(validation_input)
session.run(model, {batch_input: validation_data})
如果所有构造都使用tf.get_variable
而不是创建新变量,并且所有作用域都设置为可以获取现有变量,则只需调用build_model
函数两次-一次使用训练数据(来自tf.data
)和一次验证数据。您可以在this answer