Tensorflow Dataset API改组将性能降低9倍

时间:2018-12-02 23:31:59

标签: python performance tensorflow tensorflow-datasets

我正在使用Tensorflow Dataset API来获取一堆文件名;改组文件名;执行python函数以加载图像文件,对其进行预处理并将其转换为张量;然后缓存,重复和批处理它们。到目前为止,一切都很好。

当我在张量中添加shuffle()时,性能将降低9倍。同样,当我做self.dataset.apply(tf.data.experimental.shuffle_and_repeat(16384))时。

shuffle为什么会严重损害性能,我该如何解决?

代码:

filenames = tf.data.Dataset.list_files(self.FILE_PATTERN).shuffle(buffer_size=16384)
dataset = filenames.map(lambda filename: self.pp(filename), 
num_parallel_calls=self.N_CPUS)
dataset = dataset.cache("./cachefile")
# The line below (shuffle_and_repeat) made performance very bad (1s/step without, 9s/step with)
# dataset = dataset.apply(tf.data.experimental.shuffle_and_repeat(16384))
# This too:
# dataset = dataset.repeat().shuffle(16384)
# This works fine, but doesn't shuffle:
dataset = dataset.repeat()
dataset = dataset.batch(self.BATCH_SIZE)
dataset = dataset.prefetch(4)

1 个答案:

答案 0 :(得分:1)

尝试更改预取参数buffer_size = 2

数据集= dataset.prefetch(2)

prefetch是一个性能标志,在后台读取下一个数据集以进行下一次迭代。如果预取的buffer_size很大,那么它将创建大量的数据集进行迭代,并且由于内存不足而可能会变慢。