可重新初始化的迭代器可以解决什么问题?

时间:2018-12-02 18:46:36

标签: tensorflow-datasets

来自tf.data documentation

  

可重新初始化的迭代器可以从多个不同的对象中初始化   数据集对象。例如,您可能有一个训练输入管道   使用随机扰动对输入图像进行改进   概括,以及用于评估的验证输入管道   未修改数据的预测。这些管道通常会使用   具有相同结构(即,相同   类型和每个组件的兼容形状。

给出了以下示例:

# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(
    lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(50)

# A reinitializable iterator is defined by its structure. We could use the
# `output_types` and `output_shapes` properties of either `training_dataset`
# or `validation_dataset` here, because they are compatible.
iterator = tf.data.Iterator.from_structure(training_dataset.output_types,
                                           training_dataset.output_shapes)
next_element = iterator.get_next()

training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)

# Run 20 epochs in which the training dataset is traversed, followed by the
# validation dataset.
for _ in range(20):
  # Initialize an iterator over the training dataset.
  sess.run(training_init_op)
  for _ in range(100):
    sess.run(next_element)

  # Initialize an iterator over the validation dataset.
  sess.run(validation_init_op)
  for _ in range(50):
    sess.run(next_element)

目前尚不清楚这种复杂性的好处是什么。
为什么不简单地创建2个不同的迭代器?

1 个答案:

答案 0 :(得分:2)

重新初始化迭代器的原始动机如下:

  1. 用户的输入数据位于两个或更多具有相同结构但管道定义不同的tf.data.Dataset对象中。

    例如,您可能有一个Dataset.map()中带有增强功能的训练数据管道,以及一个生成原始示例的评估数据管道,但是它们都将生成具有相同结构的批次(就数量而言)。张量,它们的元素类型,形状等)。

  2. 用户将定义一个单独的训练图,该图从tf.data.Iterator中输入,该输入是使用Iterator.from_structure()创建的。

  3. 然后,用户可以通过重新初始化来自一个数据集的迭代器,在不同的输入数据源之间进行切换。

事后看来,可重新初始化的迭代器很难用于其预期目的。在TensorFlow 2.0(或启用了急切执行的1.x)中,使用惯用的Python for循环和高级培训API在不同的数据集上创建迭代器要容易得多:

tf.enable_eager_execution()

model = ...  # A `tf.keras.Model`, or some other class exposing `fit()` and `evaluate()` methods.

train_data = ...  # A `tf.data.Dataset`.
eval_data = ...   # A `tf.data.Dataset`.

for i in range(NUM_EPOCHS):
  model.fit(train_data, ...)

  # Evaluate every 5 epochs.
  if i % 5 == 0: 
    model.evaluate(eval_data, ...)