Tensorflow:如何使用" new"带QueueRunner的数据集API

时间:2017-11-06 08:24:59

标签: python tensorflow dataset

基本上我有一个要处理的图像列表。 我需要在加载后进行一些预处理(数据增强),然后输入TF的主图。 目前我正在使用一个定制的生成器,它采用一系列路径产生一对张量(图像)并通过占位符提供给网络。每批次的顺序处理耗时约0.5秒。

我刚刚使用Dataset函数阅读了我可以直接使用的.from_generator() API,我可以直接使用.get_next()作为输入。

但是QueueRunner如何适应框架? Dataset隐式使用queue + dequeue来维护其generator/get_next管道,还是要求我之后明确地输入FIFOQueue?如果答案是后者,那么维持管道训练+验证多个random_shuffle时期的最佳做法是什么? (我的意思是,我需要维护多少DS/queueRunner,我在哪里设置随机播放和时代?)

1 个答案:

答案 0 :(得分:1)

如果您使用的是数据集API,则不必使用QueueRunner来拥有队列/缓冲区。可以使用数据集API创建队列/缓冲区,并预处理数据并同时训练网络。如果您有数据集,则可以使用prefetch functionshuffle function创建队列/缓冲区。

有关详情,请参阅official tutorial on the Dataset API

以下是在CPU上使用带预处理的预取缓冲区的示例:

 NUM_THREADS = 8
 BUFFER_SIZE = 100

 data = ...
 labels = ...
 inputs = (data, labels)

 def pre_processing(data_, labels_):
     with tf.device("/cpu:0"):
         # do some pre-processing here
         return data_, labels_

 dataset_source = tf.data.Dataset.from_tensor_slices(inputs)
 dataset = dataset_source.map(pre_processing, num_parallel_calls=NUM_THREADS)

 dataset = dataset.repeat(1)  # repeats for one epoch
 dataset = dataset.prefetch(BUFFER_SIZE)

 iterator = tf.data.Iterator.from_structure(dataset.output_types,
                                            dataset.output_shapes)
 next_element = iterator.get_next()
 init_op = iterator.make_initializer(dataset)

 with tf.Session() as sess:
     sess.run(init_op)
     while True:
         try:
             sess.run(next_element)
         except tf.errors.OutOfRangeError:
             break