这个想法是让整个数据集运行在多个纪元之上。显然,该数据集被重复使用。调用repeat()到底能做什么?
答案 0 :(得分:0)
tf.data.TFRecordDataset().repeat(num_epochs)
的全部目的是在内存中重复数据集num_epoch
次,以便您可以使用num_epoch
对数据集进行tf.data.TFRecordDataset().make_one_shot_iterator()
次迭代。请检查下面的示例(请注意,我正在使用tf.data.Dataset()
,因为它是为了演示.repeat(num_epochs)
的用法)。
X = np.random.rand(10,2)
dataset = tf.data.Dataset.from_tensor_slices(X)
dataset = dataset.batch(5)
iterator = dataset.make_one_shot_iterator()
with tf.Session() as sess:
sess.run(iterator.get_next())
sess.run(iterator.get_next())
# The third time this is called it throws an exception
sess.run(iterator.get_next())
现在,如果我们进行tf.data.Dataset().repeat(num_epochs)
,则可以迭代数据集num_epochs
次。
X = np.random.rand(10,2)
dataset = tf.data.Dataset.from_tensor_slices(X)
dataset = dataset.repeat(2)
dataset = dataset.batch(5)
iterator = dataset.make_one_shot_iterator()
with tf.Session() as sess:
sess.run(iterator.get_next())
sess.run(iterator.get_next())
sess.run(iterator.get_next())
sess.run(iterator.get_next())
# Now, the fifth time it will throw an exception
sess.run(iterator.get_next())
但是,我建议避开tf.data.Dataset().make_one_shot_iterator()
。事实是,如果要对数据集进行一次或有限次数的迭代,则此迭代器非常有用。因此,我建议使用tf.data.Dataset().make_initializable_iterator()
。
X = np.random.rand(10,2)
dataset = tf.data.Dataset.from_tensor_slices(X)
dataset = dataset.batch(5)
iterator = dataset.make_initializable_iterator()
with tf.Session() as sess:
# Initialize the iterator to iterate over the dataset
sess.run(iterator.initializer)
sess.run(iterator.get_next())
sess.run(iterator.get_next())
# Once done, initialize it again
sess.run(iterator.initializer)
sess.run(iterator.get_next())
sess.run(iterator.get_next())
如您所见,使用可初始化的迭代器,您可以遍历数据集,完成后,再次将其重新初始化以开始新的纪元。这是一种方便的方法。
num_epochs = 3
X = np.random.rand(10,2)
dataset = tf.data.Dataset.from_tensor_slices(X)
dataset = dataset.batch(5)
iterator = dataset.make_initializable_iterator()
with tf.Session() as sess:
for e in range(num_epochs):
sess.run(iterator.initializer)
try:
while True:
sess.run(iterator.get_next())
# Or do whatever you want with the batch
except tf.errors.OutOfRangeError:
print(f"Epoch {e+1} finished, starting over!")
pass