Behavior of tf.train.string_input_producer in a loop

Date: 2017-01-28 12:42:21

Tags: tensorflow

The following excerpt is taken from the TensorFlow 0.12 API documentation:

def input_pipeline(filenames, batch_size, num_epochs=None):
  filename_queue = tf.train.string_input_producer(
      filenames, num_epochs=num_epochs, shuffle=True)
  example, label = read_my_file_format(filename_queue)
  # min_after_dequeue defines how big a buffer we will randomly sample
  #   from -- bigger means better shuffling but slower start up and more
  #   memory used.
  # capacity must be larger than min_after_dequeue and the amount larger
  #   determines the maximum we will prefetch.  Recommendation:
  #   min_after_dequeue + (num_threads + a small safety margin) * batch_size
  min_after_dequeue = 10000
  capacity = min_after_dequeue + 3 * batch_size
  example_batch, label_batch = tf.train.shuffle_batch(
      [example, label], batch_size=batch_size, capacity=capacity,
      min_after_dequeue=min_after_dequeue)
  return example_batch, label_batch

My question is probably very basic for a regular TensorFlow user, but I am an absolute beginner. The question is the following:

  • tf.train.string_input_producer creates a queue for holding the filenames. As input_pipeline() is called over and over again during training, how is it ensured that the same queue is used every time? And, I suppose importantly, if different calls to input_pipeline() result in the creation of a new queue, there seems to be no way to ensure that different images are picked each time, or that the epoch counter and shuffling are maintained correctly.

1 Answer:

Answer 0 (score: 3)

The input_pipeline function only creates the part of a (usually larger) graph that is responsible for producing batches of data. If you were to call input_pipeline twice - for whatever reason - you would indeed create two different queues.
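
To see this concretely, here is a small sketch (the file names and the graph-inspection step are illustrative, not part of the original answer) showing that each call to tf.train.string_input_producer adds its own queue operation to the default graph:

import tensorflow as tf

filenames = ['file_a.csv', 'file_b.csv']  # placeholder file names

# Two separate calls create two independent queue nodes.
queue1 = tf.train.string_input_producer(filenames)
queue2 = tf.train.string_input_producer(filenames)

# List the queue operations now present in the default graph; this should
# print two entries under distinct name scopes, e.g. 'input_producer'
# and 'input_producer_1'.
queue_ops = [op.name for op in tf.get_default_graph().get_operations()
             if 'FIFOQueue' in op.type]
print(queue_ops)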

In general, tf.train.string_input_producer actually creates a queue node (or operation) in the currently active graph (which is the default graph unless you specify something different). read_my_file_format then reads from that queue and successively produces single "example" tensors, while tf.train.shuffle_batch batches these into bundles of length batch_size each.
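
read_my_file_format itself is not defined in the excerpt. A minimal sketch of such a reader, assuming for illustration that each file is a CSV with three feature columns and an integer label (along the lines of the reading-data documentation), might look like this:

def read_my_file_format(filename_queue):
  # Illustrative assumption: lines of the form "f1,f2,f3,label".
  reader = tf.TextLineReader()
  key, record_string = reader.read(filename_queue)
  col1, col2, col3, label = tf.decode_csv(
      record_string, record_defaults=[[0.0], [0.0], [0.0], [0]])
  example = tf.stack([col1, col2, col3])  # tf.pack before TensorFlow 0.12
  return example, label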

However, the outputs of tf.train.shuffle_batch - the two tensors returned by the input_pipeline function - only actually take on (new) values when they are evaluated in a session. If you evaluate these tensors multiple times, they will contain different values - drawn, via read_my_file_format, from the files listed in the input queue.

Think of it like this:

X_batch, Y_batch = input_pipeline(..., batch_size=100)

with tf.Session() as sess:
   sess.run(tf.global_variables_initializer())
   tf.train.start_queue_runners()

   # get the first 100 examples and labels
   X1, Y1 = sess.run((X_batch, Y_batch))

   # get the next 100 examples and labels
   X2, Y2 = sess.run((X_batch, Y_batch))

   # etc.

The boilerplate code needed to make this actually run is a bit more involved, e.g. because queues need to be started and stopped in the graph, and because they throw a tf.errors.OutOfRangeError when they run dry. A more complete example could look like this:

with tf.Graph().as_default() as graph:
   X_batch, Y_batch = input_pipeline(..., batch_size=100)

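   # inference() and optimize() stand in for your model-building code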
   prediction = inference(X_batch)
   optimizer, loss = optimize(prediction, Y_batch)

coord = tf.train.Coordinator()
with tf.Session(graph=graph) as sess:
   init = tf.group(tf.local_variables_initializer(),
                   tf.global_variables_initializer())
   sess.run(init)

   # start the queue runners
   threads = tf.train.start_queue_runners(coord=coord)

   try:
       while not coord.should_stop():

           # now you're really indirectly querying the
           # queue; each iteration will see a new batch of
           # at most 100 values.
           # use a new name so the 'loss' tensor is not rebound to a value
           _, loss_value = sess.run((optimizer, loss))

           # you might also want to do something with
           # the network's output - again, this would
           # use a fresh batch of inputs
           some_predicted_values = sess.run(prediction)

   except tf.errors.OutOfRangeError:
       print('Training stopped, input queue is empty.')
   finally:
       # stop the queue(s) and wait for the runner threads to finish
       coord.request_stop()
       coord.join(threads)
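
A final note on the initializers above: when num_epochs is set, tf.train.string_input_producer keeps its epoch counter in a local variable, which is why tf.local_variables_initializer() must run before the queue runners are started. As a minimal sketch (reusing input_pipeline from the excerpt with a placeholder file list), the single-epoch behavior looks like this:

X_batch, Y_batch = input_pipeline(['data.csv'], batch_size=100, num_epochs=1)

coord = tf.train.Coordinator()
with tf.Session() as sess:
    # Both initializers are needed: the epoch counter is a local variable.
    sess.run(tf.group(tf.local_variables_initializer(),
                      tf.global_variables_initializer()))
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        while not coord.should_stop():
            sess.run((X_batch, Y_batch))
    except tf.errors.OutOfRangeError:
        # Raised after exactly one pass over all files.
        print('Epoch finished.')
    finally:
        coord.request_stop()
        coord.join(threads)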

For a deeper understanding, you may want to have a look at the Reading data documentation.