为什么需要调用TFRecordDataset.repeat()?

时间:2019-08-11 03:05:44

标签: tensorflow

这个想法是让整个数据集运行在多个纪元之上。显然,该数据集被重复使用。调用repeat()到底能做什么?

1 个答案:

答案 0 :(得分:0)

tf.data.TFRecordDataset().repeat(num_epochs)的全部目的是在内存中重复数据集num_epoch次,以便您可以使用num_epoch对数据集进行tf.data.TFRecordDataset().make_one_shot_iterator()次迭代。请检查下面的示例(请注意,我正在使用tf.data.Dataset(),因为它是为了演示.repeat(num_epochs)的用法)。

X = np.random.rand(10,2)
dataset = tf.data.Dataset.from_tensor_slices(X)
dataset = dataset.batch(5)
iterator = dataset.make_one_shot_iterator()
with tf.Session() as sess:
    sess.run(iterator.get_next())
    sess.run(iterator.get_next())
    # The third time this is called it throws an exception
    sess.run(iterator.get_next())

现在,如果我们进行tf.data.Dataset().repeat(num_epochs),则可以迭代数据集num_epochs次。

X = np.random.rand(10,2)
dataset = tf.data.Dataset.from_tensor_slices(X)
dataset = dataset.repeat(2)
dataset = dataset.batch(5)
iterator = dataset.make_one_shot_iterator()
with tf.Session() as sess:
    sess.run(iterator.get_next())
    sess.run(iterator.get_next())
    sess.run(iterator.get_next())
    sess.run(iterator.get_next())
    # Now, the fifth time it will throw an exception
    sess.run(iterator.get_next())

但是,我建议避开tf.data.Dataset().make_one_shot_iterator()。事实是,如果要对数据集进行一次或有限次数的迭代,则此迭代器非常有用。因此,我建议使用tf.data.Dataset().make_initializable_iterator()

X = np.random.rand(10,2)
dataset = tf.data.Dataset.from_tensor_slices(X)
dataset = dataset.batch(5)
iterator = dataset.make_initializable_iterator()
with tf.Session() as sess:
    # Initialize the iterator to iterate over the dataset
    sess.run(iterator.initializer)
    sess.run(iterator.get_next())
    sess.run(iterator.get_next())
    # Once done, initialize it again
    sess.run(iterator.initializer)
    sess.run(iterator.get_next())
    sess.run(iterator.get_next())

如您所见,使用可初始化的迭代器,您可以遍历数据集,完成后,再次将其重新初始化以开始新的纪元。这是一种方便的方法。

num_epochs = 3
X = np.random.rand(10,2)
dataset = tf.data.Dataset.from_tensor_slices(X)
dataset = dataset.batch(5)
iterator = dataset.make_initializable_iterator()
with tf.Session() as sess:
    for e in range(num_epochs):
        sess.run(iterator.initializer)
        try:
            while True:
                sess.run(iterator.get_next()) 
                # Or do whatever you want with the batch
        except tf.errors.OutOfRangeError:
            print(f"Epoch {e+1} finished, starting over!")
            pass