完成一个纪元后,Tensorflow Dataset API恢复Iterator

时间:2018-03-11 05:28:20

标签: tensorflow tensorflow-datasets

我有190个功能和标签,我的批量大小为20但是在9次迭代后tf.reshape返回异常重新整形的输入是一个具有21个值的张量,但请求的形状有60个我知道这是由于Iterator.get_next()。我如何恢复我的迭代器,以便它从一开始就会再次开始提供批次?

1 个答案:

答案 0 :(得分:4)

如果要从Dataset的开头重新启动tf.data.Iterator,请考虑使用可初始化的迭代器,该迭代器具有可以运行的操作以重新初始化迭代器:

dataset = ...  # A `tf.data.Dataset` instance.
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

train_op = ...  # Something that depends on `next_element`.

for _ in range(NUM_EPOCHS):
  # Initialize the iterator at the beginning of `dataset`.
  sess.run(iterator.initializer)

  # Loop over the examples in `iterator`, running `train_op`.
  try:
    while True:
      sess.run(train_op)

  except tf.errors.OutOfRangeError:  # Thrown at the end of the epoch.
    pass

  # Perform any per-epoch computations here.

有关不同类型Iterator的详细信息,请参阅the tf.data programmer's guide