我正在使用TensorFlow的数据集API,根据文档,我对shuffle()方法感到困惑:
Dataset.shuffle()转换使用与tf.RandomShuffleQueue类似的算法随机混洗输入数据集:它维护一个固定大小的缓冲区,并从该缓冲区中随机选择下一个元素。
如果我只是'部分'洗牌我的数据集(例如buffer_size< = no。的元素),我希望只有第一个 buffer_size 元素会被洗牌,但事实并非如此,见例子:
dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7,8])
.shuffle(buffer_size=4, seed=42)
.batch(2)
iter = dataset.make_initializable_iterator() # create the iterator
el = iter.get_next()
with tf.Session() as sess:
sess.run(iter.initializer)
print('batch:', sess.run(el))
输出:
batch: [2 5]
为什么5在这里?因为缓冲区大小只有4?前2个元素应该在1~4之内吗?我在这里错过了什么?
由于
答案 0 :(得分:3)
简短的回答是,可以随时补充随机缓冲区,包括在创建批处理的过程中。
以下是您的观察结果: