TensorFlow TFRecordDataset shuffle buffer_size行为

时间:2018-02-15 22:09:54

标签: python tensorflow tensorflow-datasets

我不清楚tf.TFRecordDataset中的buffer_size参数是做什么的。我们假设我们有以下代码:

dataset = dataset.shuffle(buffer_size=10000).repeat().batch(batch_size)

这是否意味着只会使用前10k个样本并永久重复,还是会遍历整个数据集?如果不是,它到底是什么?这个代码怎么样?

dataset = dataset.repeat().shuffle(buffer_size=10000).batch(batch_size)

我注意到this post,但它没有对buffer_size说些什么。

1 个答案:

答案 0 :(得分:2)

answer可能有助于更好地理解buffer_size方法的shuffle参数。

简而言之,数据集在其缓冲区中始终具有多个buffer_size个元素,并且每次添加元素时都会对此缓冲区进行洗牌。

因此,缓冲区大小为1就像没有改组一样,拥有数据集长度的缓冲区就像传统的混乱一样。

要了解改组和重复数据集之间的正确顺序,请查看官方performance guide

最佳做法通常是随机播放然后重复,因为这将确保您在每个时期都能看到整个数据集。