I have a training model that takes all of the training data and builds an input queue:
x = tf.placeholder(tf.float32, (N, steps, size), name='x')
y = tf.placeholder(tf.float32, (N, out_size), name='y')
var_x = tf.Variable(x, trainable=False, collections=[])
var_y = tf.Variable(y, trainable=False, collections=[])
x_queue, y_queue = tf.train.slice_input_producer([var_x, var_y],
                                                 num_epochs=10, shuffle=True)
x_batch, y_batch = tf.train.batch([x_queue, y_queue], batch_size=batch_size)
...
with tf.Session() as sess:
    # Initialize the data variables from the placeholders
    # (running the variables themselves would fail before initialization).
    sess.run(var_x.initializer, feed_dict={x: X})
    sess.run(var_y.initializer, feed_dict={y: Y})
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
...
This network works fine and I am able to train it. Now I want to add a new placeholder to this network to take my test data:
x_test = tf.placeholder(tf.float32, (1, steps, size), name='x_test')
I would like to use tf.cond to control which placeholder is fed:
rnn_inputs = tf.cond(is_train, lambda: x, lambda: x_test)
However, many posts say that using tf.cond is inefficient. On top of that, a separate placeholder for the test/validation data is itself a problem: even when I am only trying to train the model, TensorFlow throws an error demanding that I feed data into it as well.

Is there a standard way to do this?
Answer 0 (score: 1)
The most efficient way is to feed the data through an iterator. You can create a string handle that specifies whether to feed from the training or the validation dataset. Below is the example from https://www.tensorflow.org/programmers_guide/datasets; I have found this approach to work well.
# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
    lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(50)

# A feedable iterator is defined by a handle placeholder and its structure. We
# could use the output_types and output_shapes properties of either
# training_dataset or validation_dataset here, because they have
# identical structure.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()

# You can use feedable iterators with a variety of different kinds of iterator
# (such as one-shot and initializable iterators).
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()

# The Iterator.string_handle() method returns a tensor that can be evaluated
# and used to feed the handle placeholder.
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())

# Loop forever, alternating between training and validation.
while True:
    # Run 200 steps using the training dataset. Note that the training dataset
    # is infinite, and we resume from where we left off in the previous
    # `while` loop iteration.
    for _ in range(200):
        sess.run(next_element, feed_dict={handle: training_handle})

    # Run one pass over the validation dataset.
    sess.run(validation_iterator.initializer)
    for _ in range(50):
        sess.run(next_element, feed_dict={handle: validation_handle})
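Applied to the question's setup, the same idea can be even simpler: instead of copying the placeholders into variables and using tf.train.slice_input_producer, feed the arrays through a single initializable iterator and re-initialize it with either the training or the test array. This is a minimal sketch, not the answer's code; the shapes, the array names X_train/X_test, and the tf.compat.v1 shim (so it also runs under TF 2) are my assumptions:

```python
import numpy as np
import tensorflow.compat.v1 as tf  # shim so the graph-mode API works under TF 2

tf.disable_v2_behavior()

# Hypothetical dimensions standing in for the question's steps/size.
steps, size, batch_size = 5, 3, 2

# One placeholder serves both training and test data: the batch dimension is
# left as None, and the iterator is simply re-initialized with a new array.
x = tf.placeholder(tf.float32, (None, steps, size), name='x')
dataset = tf.data.Dataset.from_tensor_slices(x).batch(batch_size)
iterator = tf.data.make_initializable_iterator(dataset)
next_batch = iterator.get_next()

X_train = np.random.rand(10, steps, size).astype(np.float32)
X_test = np.random.rand(4, steps, size).astype(np.float32)

with tf.Session() as sess:
    # Point the pipeline at the training data.
    sess.run(iterator.initializer, feed_dict={x: X_train})
    train_batch = sess.run(next_batch)

    # Re-initialize the same iterator with the test data -- no tf.cond needed.
    sess.run(iterator.initializer, feed_dict={x: X_test})
    test_batch = sess.run(next_batch)
```

Because the data enters through the iterator's initializer rather than the training graph's feed_dict, no placeholder is left unfed during training, which avoids the error described in the question.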