TL; DR:在tensorflow 0.1.4中使用Dataset api时,如何确保以多线程方式加载数据?
以前我用磁盘中的图像做了类似的事情:
filename_queue = tf.train.string_input_producer(filenames)
image_reader = tf.WholeFileReader()
_, image_file = image_reader.read(filename_queue)
imsize = 120
image = tf.image.decode_jpeg(image_file, channels=3)
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
image_r = tf.image.resize_images(image, [imsize, imsize])
images = tf.train.shuffle_batch([image_r],
batch_size=20,
num_threads=30,
capacity=200,
min_after_dequeue=0)
这确保将有20个线程为下一次学习迭代准备好数据。
现在使用Dataset api我会做类似的事情:
dataset = tf.data.Dataset.from_tensor_slices((filenames, filenames_up, filenames_blacked))
dataset = dataset.map(parse_upscaler_corrector_batch)
在此之后我创建了一个迭代器:
sess = tf.Session();
iterator = dataset.make_initializable_iterator();
next_element = iterator.get_next();
sess.run(iterator.initializer);
value = sess.run(next_element)
将传递变量值以供进一步处理。
那么我如何确保以多线程方式准备数据呢?我在哪里可以阅读有关Dataset api和多线程数据的信息?
答案 0 :(得分:2)
所以看来实现这一目标的方法如下:
dataset = dataset.map(parse_upscaler_corrector_batch, num_parallel_calls=12).prefetch(32).batch(self.ex_config.batch_size)
如果更改num_parallel_calls = 12,则可以看到网络/硬盘负载和CPU加载峰值或减少。