Tensorflow数据输入切换:训练/验证

时间:2016-09-21 13:17:11

标签: python tensorflow

在我从方便但速度较低的占位符切换后,我的数据通过队列运行程序进入我的图表。

在每个训练时期之后,我希望运行验证通行证。除了训练传递之外,验证传递使用不同的数据,没有增加,也没有改组。

问题很简单:如何切换这些东西?

一些观察结果:

  • 我无法通过shuffle布尔值切换string_input_producer中的tf.placeholder选项。
  • 我发现的唯一在线示例使用placeholder从验证数据中分离培训。反过来,不要使用优秀的队列跑步者。
  • 我确实设法使用tf.cond()执行上述操作,我将测试我通过is_training的{​​{1}} tf.placeholder布尔值。这个解决方案是最优化的吗?这个feed_dict方法有多贵?

2 个答案:

答案 0 :(得分:3)

适用于我的方法是使用tf.placeholder_with_default

images_train, labels_train = train_data_pipeline(fnlist_train, ref_grid)
images_val, labels_val = val_data_pipeline(fnlist_val, ref_grid)
images = tf.placeholder_with_default(images_train, shape=[None, FLAGS.nx_image, FLAGS.ny_image, FLAGS.nz_image])
labels = tf.placeholder_with_default(labels_train, shape=[None, label_length])

在培训期间,imageslabels直接来自培训队列。对于间歇性验证步骤,我通过调用images通过feed_dict提供labelssess.run()。唯一的小问题是,验证数据也是队列中的张量,feed_dict不接受张量,所以我先调用sess.run([images_val, labels_val])来获取numpy值,然后在feed_dict中使用它们。似乎运行良好,并且张量==> numpy ==>张量转换的延迟最小,这只会在验证过程中发生。

对于验证数据具有单独处理要求的情况,可以在设置单独的验证队列并处理流程时处理此问题。

答案 1 :(得分:1)

一个可能的答案是使用make_template 这在https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/kernel_tests/template_test.py中概述;它基本上说可以做到这一点:

training_input, training_output = ([1., 2., 3., 4.], [2.8, 5.1, 7.2, 8.7])
test_input, test_output = ([5., 6., 7., 8.], [11, 13, 15, 17])

tf.set_random_seed(1234)

def test_line(x):
  m = tf.get_variable("w", shape=[],
                      initializer=tf.truncated_normal_initializer())
  b = tf.get_variable("b", shape=[],
                      initializer=tf.truncated_normal_initializer())
  return x * m + b

line_template = template.make_template("line", test_line)

train_prediction = line_template(training_input)
test_prediction = line_template(test_input)

train_loss = tf.reduce_mean(tf.square(train_prediction - training_output))
test_loss = tf.reduce_mean(tf.square(test_prediction - test_output))

optimizer = tf.train.GradientDescentOptimizer(0.1)
train_op = optimizer.minimize(train_loss)

with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  initial_test_loss = sess.run(test_loss)
  sess.run(train_op)
  final_test_loss = sess.run(test_loss)

# Parameters are tied, so the loss should have gone down when we trained it.
self.assertLess(final_test_loss, initial_test_loss)