`tf.data.Dataset.repeat()`缓冲内存中的整个数据集吗?

时间:2017-11-05 02:05:43

标签: tensorflow

从TF文档中查看此代码示例:

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat(num_epochs)
iterator = dataset.make_one_shot_iterator()

dataset.repeat(num_epochs)是否要求将整个数据集加载到内存中?或者它是否在收到数据集结束异常时重新初始化之前的数据集?

关于这一点,文档含糊不清。

1 个答案:

答案 0 :(得分:0)

基于这个简单的测试,似乎repeat 缓冲数据集,它必须重新初始化上游数据集。

n = tf.data.Dataset.range(5).shuffle(buffer_size=5).repeat(2).make_one_shot_iterator().get_next()
[sess.run(n) for _ in range(10)]
Out[83]: [2, 0, 3, 1, 4, 3, 1, 0, 2, 4]

逻辑表明,如果repeat缓冲了它的输入,那么在这个简单的实验中就会重复相同的随机混洗模式。