数据集api中的多线程

时间:2017-12-05 12:28:01

标签: python-3.x tensorflow-gpu

TL; DR:在tensorflow 0.1.4中使用Dataset api时,如何确保以多线程方式加载数据?

以前我用磁盘中的图像做了类似的事情:

filename_queue = tf.train.string_input_producer(filenames)    
image_reader = tf.WholeFileReader()
_, image_file = image_reader.read(filename_queue)    
imsize = 120    
image = tf.image.decode_jpeg(image_file, channels=3)
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
image_r = tf.image.resize_images(image, [imsize, imsize])    
images = tf.train.shuffle_batch([image_r],
    batch_size=20,
    num_threads=30,
    capacity=200,
    min_after_dequeue=0)

这确保将有20个线程为下一次学习迭代准备好数据。

现在使用Dataset api我会做类似的事情:

dataset = tf.data.Dataset.from_tensor_slices((filenames, filenames_up, filenames_blacked))
dataset = dataset.map(parse_upscaler_corrector_batch)

在此之后我创建了一个迭代器:

sess = tf.Session();
iterator = dataset.make_initializable_iterator();
next_element = iterator.get_next();
sess.run(iterator.initializer); 
value = sess.run(next_element)

将传递变量以供进一步处理。

那么我如何确保以多线程方式准备数据呢?我在哪里可以阅读有关Dataset api和多线程数据的信息?

1 个答案:

答案 0 :(得分:2)

所以看来实现这一目标的方法如下:

dataset = dataset.map(parse_upscaler_corrector_batch, num_parallel_calls=12).prefetch(32).batch(self.ex_config.batch_size)

如果更改num_parallel_calls = 12,则可以看到网络/硬盘负载和CPU加载峰值或减少。