在tensorflow的输入管道中进行线程处理

时间:2017-03-25 10:03:37

标签: multithreading machine-learning tensorflow neural-network

背景

tensorflow中的典型输入管道如下所示:

                  tf.train.string_input_producer(list_of_filenames)
                         (creates queue of filenames)
                                   |
                                  \|/
           fixed length reader reads records from the files
                                   |
                                  \|/
    Read records are decoded and processed(eg if dealing with images then cropping,flipping etc)
                                   |
                                  \|/
            tf.train.shuffle_batch(tensors,num_threads)
        (creates a shuffling queue and returns batches of tensors) 

问题

Q1)函数tf.train.string_input_producer()中没有num_threads的参数。这是否意味着只有一个线程专门用于从文件名队列中读取文件名?

Q2)函数tf.train.shuffle_batch()的num_threads参数的范围是什么,即此处提到的线程数用于读取,解码和处理文件,或者它们仅用于创建批量的张量?

Q3)有没有办法打印哪个线程从特定文件读取文件名或记录,即每个线程完成的工作记录?

1 个答案:

答案 0 :(得分:4)

所有数据加载操作都在张量流图中执行,您要做的是启动一个或多个线程来迭代读取器/入队操作。 Tensorflow提供了一个完全相同的QueueRunner类。 Coordinator类允许您非常简单地管理线程。

https://www.tensorflow.org/programmers_guide/threading_and_queues

这是上面链接的示例代码:

# Create a queue runner that will run 4 threads in parallel to enqueue
# examples.
qr = tf.train.QueueRunner(queue, [enqueue_op] * 4)

# Launch the graph.
sess = tf.Session()
# Create a coordinator, launch the queue runner threads.
coord = tf.train.Coordinator()
enqueue_threads = qr.create_threads(sess, coord=coord, start=True)
# Run the training loop, controlling termination with the coordinator.
for step in xrange(1000000):
    if coord.should_stop():
        break
    sess.run(train_op)
# When done, ask the threads to stop.
coord.request_stop()
# And wait for them to actually do it.
coord.join(enqueue_threads)

如果您在图表之外加载/预处理样本(在您自己的代码中,不使用TF操作),那么您将不会使用QueueRunner,而是使用您自己的类来使用{{将数据排入队列1}}循环中的命令。

Q1:处理的线程数为:sess.run(enqueue_op, feed_dict={...})

Q2:TF会话是线程安全的,每次调用qr.create_threads(sess, coord=coord, start=True)都会看到当前变量开始时的一致快照。您的QueueRunner入队操作可以运行任意数量的线程。他们将以线程安全的方式排队。

问题3:我自己还没有使用tf.run(...),但我认为您必须在图表中稍后请求张量tf.train.string_input_producer数据,只需将张量添加到您在dequeued

中的请求列表