TF数据集迭代器只前进一次而不是两次

时间:2018-01-28 19:07:07

标签: python tensorflow tensorflow-datasets

我正在使用Tensorflow 1.4.1并了解Tensorflow Dataset API。在描述consuming values from an iterator的部分中,有以下示例

bar.txt

......附带以下说明:

  

请注意,评估dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10])) dataset2 = tf.data.Dataset.from_tensor_slices((tf.random_uniform([4]), tf.random_uniform([4, 100]))) dataset3 = tf.data.Dataset.zip((dataset1, dataset2)) iterator = dataset3.make_initializable_iterator() sess.run(iterator.initializer) next1, (next2, next3) = iterator.get_next() next1next2中的任何一个都会推进   所有组件的迭代器。迭代器的典型消费者会   将所有组件包含在单个表达式中。

我决定通过以下简单示例测试此行为。

next3

如您所见,我正在使用import tensorflow as tf dataset1 = tf.data.Dataset.range(5) dataset2 = tf.data.Dataset.range(5) dataset3 = tf.data.Dataset.zip((dataset1, dataset2)) iterator = dataset3.make_initializable_iterator() with tf.Session() as sess: sess.run(iterator.initializer) next1, next2 = iterator.get_next() A = next1 B = next1 + next2 while True: try: a, b = sess.run([A,B]) print(a,b) except tf.errors.OutOfRangeError: print('done') break next1两个表达式评估A。根据上面的引用,如果迭代器确实是针对每个评估进行的,我期望迭代器在这两个评估中都有进展,并且看到包含

的打印输出
B

但是,我得到的是:

(0, 2)
(2, 6)

为什么迭代器只提前一次?什么是一个显示我期望看到的行为的工作示例?

1 个答案:

答案 0 :(得分:4)

当您在TensorFlow图中有一个改变状态的操作(如iterator.get_next())时,通常会出现混淆。规则很简单:

  

图表中的每个有状态操作(不在tf.while_loop()tf.cond()中)每个Session.run()调用只执行一次。

应用该规则,图中只有一个iterator.get_next() op,因此迭代器只会在Session.run()次调用时前进一次,并且相同的元素将用于计算{{1} }和A

要获得所需的行为,您需要创建第二个B操作。此外,为了获得确定性行为,我们需要确保两个iterator.get_next()操作之间存在控制依赖关系。例如,以下程序展示了您期望的行为:

iterator.get_next()