从TF文档中查看此代码示例:
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat(num_epochs)
iterator = dataset.make_one_shot_iterator()
dataset.repeat(num_epochs)
是否要求将整个数据集加载到内存中?或者它是否在收到数据集结束异常时重新初始化之前的数据集?
关于这一点,文档含糊不清。
答案 0 :(得分:0)
基于这个简单的测试,似乎repeat
不缓冲数据集,它必须重新初始化上游数据集。
n = tf.data.Dataset.range(5).shuffle(buffer_size=5).repeat(2).make_one_shot_iterator().get_next()
[sess.run(n) for _ in range(10)]
Out[83]: [2, 0, 3, 1, 4, 3, 1, 0, 2, 4]
逻辑表明,如果repeat
缓冲了它的输入,那么在这个简单的实验中就会重复相同的随机混洗模式。