具有部分shuffle的Tensorflow数据集

时间:2018-05-09 13:47:54

标签: python tensorflow

我正在使用TensorFlow的数据集API,根据文档,我对shuffle()方法感到困惑:

  

Dataset.shuffle()转换使用与tf.RandomShuffleQueue类似的算法随机混洗输入数据集:它维护一个固定大小的缓冲区,并从该缓冲区中随机选择下一个元素。

如果我只是'部分'洗牌我的数据集(例如buffer_size< = no。的元素),我希望只有第一个 buffer_size 元素会被洗牌,但事实并非如此,见例子:

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8])
                         .shuffle(buffer_size=4, seed=42)
                         .batch(2)
iter = dataset.make_initializable_iterator() # create the iterator
el = iter.get_next()
with tf.Session() as sess:
    sess.run(iter.initializer) 
    print('batch:', sess.run(el))

输出:

batch: [2 5]

为什么5在这里?因为缓冲区大小只有4?前2个元素应该在1~4之内吗?我在这里错过了什么?

由于

1 个答案:

答案 0 :(得分:3)

简短的回答是,可以随时补充随机缓冲区,包括在创建批处理的过程中。

以下是您的观察结果:

  • 数据集从您的数据中读取前4个元素。 shuffle缓冲区现在包含[1,2,3,4]
  • 您请求两个元素(通过数据集上的get_next()来创建2个批次)
  • shuffle数据集选择2并将下一个元素读入shuffle缓冲区,该缓冲区现在包含[1,3,4,5]。
  • shuffle数据集从缓冲区中选取5个。
  • 您的[2,5]批返回。