如何理解tf队列中的多线程?

时间:2017-08-01 09:08:56

标签: tensorflow

Tensorflow为我们提供了两种实现阅读数据的方法。

第一种方式,使用许多读者,例如host='localhost',每个线程一个读者。

第二种方式,只使用一个阅读器和多队列的入队操作,例如我们可以使用tf.TextLineReader并将tf.train.shuffle_batch设置为大于1。

我无法理解第二种方式,我们只有一个读者加载数据(也许是一个线程),为什么我们需要这么多线程来排队?

首先,我们应该使用num_threads,我们可以设置tf.train.shuffle_batch_join参数,所以我认为第一种方式是可以理解的。

任何人都可以给我一些解释,为什么我们需要一个读者但需要很多线程才能入队?

1 个答案:

答案 0 :(得分:0)

根据我的理解,它们仅在输入类型方面有所不同。 tf.train.batch_join接受样本列表,而tf.train.batch一次只接收一个样本。

tf.train.batch_join的情况下,预期队列的输入样本将由一个或多个线程提供。因此,通过输入中的样本数来隐含地控制线程数。另一方面,tf.train.batch具有num_threads参数,以便能够创建多个排队操作。第一个参数tensorshttps://www.tensorflow.org/api_docs/python/tf/train/batch)由多个线程初始化。

使用多个enqueue线程可以隐藏延迟。考虑一个简单的场景,其中排队样本负责的读取器/功能/操作需要大量时间。在这种情况下,拥有多个线程可以节省生命。以下代码段给出了一个示例:

import tensorflow as tf
import numpy as np
import time

num_steps = 50
batch_size = 2
input_shape, target_shape = (2, 2), ()
num_threads = 4
queue_capacity = 10

get_random_data_sample函数随机生成(输入,目标)对,并具有一些随机延迟。

def get_random_data_sample():
    # Random inputs and targets
    np_input = np.float32(np.random.normal(0,1, (2,2)))
    np_target = np.int32(1)

    # Sleep randomly between 1 and 3 seconds.
    time.sleep(np.random.randint(1,3,1)[0])

    return np_input, np_target

# Wraps a python function and uses it as a TensorFlow op.
tensorflow_input, tensorflow_target = tf.py_func(get_random_data_sample, [], [tf.float32, tf.int32])

请注意,[tensorflow_input, tensorflow_target]操作(即读者)由num_threads多个线程运行。您可以通过更改num_threads并观察运行时间来验证这一点。

### tf.train.batch ###
batch_inputs, batch_targets = tf.train.batch([tensorflow_input, tensorflow_target], 
                                             batch_size=batch_size, 
                                             num_threads=num_threads, 
                                             shapes=[input_shape, target_shape], 
                                             capacity=queue_capacity)

sess = tf.InteractiveSession()
tf.train.start_queue_runners()
start_t = time.time()
for i in range(num_steps):
    numpy_inputs, numpy_targets = sess.run([batch_inputs, batch_targets])
print(time.time()-start_t)

inputs在以下示例中包含num_threads个元素。

### tf.train.batch_join ###
inputs = [[tensorflow_input, tensorflow_target] for _ in range(num_threads)]
join_batch_inputs, join_batch_targets = tf.train.batch_join(inputs,
                                                    shapes=[input_shape, target_shape],
                                                    batch_size=batch_size,
                                                    capacity=queue_capacity,)
sess = tf.InteractiveSession()
tf.train.start_queue_runners()
start_t = time.time()
for i in range(num_steps):
    numpy_inputs, numpy_targets = sess.run([join_batch_inputs, join_batch_targets])
print(time.time()-start_t)

tf.train.batch中的多个排队线程可以有效地隐藏延迟。您可以使用num_threadsqueue_capacity进行游戏。两项业务的表现非常相似。