如何从tensorflow数据集迭代器返回两次相同的批处理?

时间:2018-03-19 08:41:20

标签: python tensorflow tensorflow-datasets

我正在转换一些遗留代码以使用数据集API - 此代码使用feed_dict将一个批次提供给列车操作(实际上是三次),然后使用同一批次重新计算显示的损失。所以我需要一个迭代器,它返回完全相同的批次两(或几次)。不幸的是,我似乎找不到使用张量流数据集的方法 - 是否可能?

1 个答案:

答案 0 :(得分:7)

您可以一起使用Dataset.flat_map()Dataset.from_tensors()Dataset.repeat()重复Dataset的各个元素。例如,要重复两次元素:

NUM_REPEATS = 2
dataset = tf.data.Dataset.range(10)  # ...or the output of `.batch()`, etc.

# Repeat each element of `dataset` NUM_REPEATS times.
dataset = dataset.flat_map(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(NUM_REPEATS))