当与repeat()和batch()一起使用时,TensorFlow dataset.shuffle()行为

时间:2018-08-02 09:23:57

标签: tensorflow dataset

这到底会做什么?

dataset = tf.data.Dataset.from_tensor_slices([0, 0, 0, 1, 1, 1, 2, 2, 2])
dataset.shuffle(buffer_size=5).repeat().batch(3)

我注意到了几个相关的问题,但是没有一个问题能完全回答我的担忧。我对shuffle(buffer_size)在做什么感到困惑。我知道它将把5个第一个示例[0, 0, 0, 1, 1]存入内存,但是此缓冲区接下来将如何处理?以及该缓冲区如何与repeat()batch()相互作用?

1 个答案:

答案 0 :(得分:2)

洗牌的工作方式很复杂,但是您可以通过先填充一个大小为buffer_size的缓冲区,然后每次请求一个元素时,在该缓冲区中采样一个统一的随机位置并将其替换为一个新元素来假装它起作用

在改组之前进行批处理意味着您将对预先制成的微型批处理进行混洗(因此,微型批处理本身不会改变,只是它们的顺序而已),而在改组之后进行批处理可以让您随机更改批处理的内容。同样,在改组前重复意味着您将对无限的流示例进行改组(因此第二个时代将与第一个时代具有不同的顺序),而在改组之后重复意味着您将始终在每个时代看到相同的示例。