Tensorflow Data API - 预取

时间:2017-11-01 22:31:03

标签: tensorflow prefetch tensorflow-datasets

我正在尝试使用TF的新功能,即Data API,我不确定预取的工作原理。在下面的代码中

def dataset_input_fn(...)
    dataset = tf.data.TFRecordDataset(filenames, compression_type="ZLIB")
    dataset = dataset.map(lambda x:parser(...))
    dataset = dataset.map(lambda x,y: image_augmentation(...)
                      , num_parallel_calls=num_threads
                     )

    dataset = dataset.shuffle(buffer_size)
    dataset = dataset.batch(batch_size)    
    dataset = dataset.repeat(num_epochs)
    iterator = dataset.make_one_shot_iterator()

我放dataset=dataset.prefetch(batch_size)之上的每一行之间有关系吗?或者,如果数据集来自output_buffer_size,那么它应该在每次使用tf.contrib.data的操作之后?

1 个答案:

答案 0 :(得分:11)

github的讨论中,我发现了mrry的评论:

  

请注意,在TF 1.4中会有一个Dataset.prefetch()方法   这样可以更容易地在管道中的任何位置添加预取,而不是   就在map()之后。 (您可以通过下载当前每晚来尝试   建立。)

  

例如,Dataset.prefetch()将启动后台线程   填充有序缓冲区,其作用类似于tf.FIFOQueue,以便   下游管道阶段不需要阻止。但是,prefetch()   实现要简单得多,因为它不需要支持   许多不同的并发操作作为tf.FIFOQueue。

所以它意味着prefetch可以由任何命令放置,它适用于上一个命令。到目前为止,我已经注意到最大的性能提升仅仅是在最后。

还有一个关于Meaning of buffer_size in Dataset.map , Dataset.prefetch and Dataset.shuffle的讨论,其中mrry解释了有关预取和缓冲的更多信息。

更新2018/10/01

从版本1.7.0开始,Dataset API(在contrib中)有一个prefetch_to_device选项。请注意,此转换必须是管道中的最后一个,当TF 2.0到达时contrib将消失。要在多个GPU上进行预取,请使用MultiDeviceIterator(例如,请参阅#13610multi_device_iterator_ops.py

https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/data/prefetch_to_device