我正在使用Tensorflow 1.4.1并了解Tensorflow Dataset API。在描述consuming values from an iterator的部分中,有以下示例
bar.txt
......附带以下说明:
请注意,评估
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10])) dataset2 = tf.data.Dataset.from_tensor_slices((tf.random_uniform([4]), tf.random_uniform([4, 100]))) dataset3 = tf.data.Dataset.zip((dataset1, dataset2)) iterator = dataset3.make_initializable_iterator() sess.run(iterator.initializer) next1, (next2, next3) = iterator.get_next()
,next1
或next2
中的任何一个都会推进 所有组件的迭代器。迭代器的典型消费者会 将所有组件包含在单个表达式中。
我决定通过以下简单示例测试此行为。
next3
如您所见,我正在使用import tensorflow as tf
dataset1 = tf.data.Dataset.range(5)
dataset2 = tf.data.Dataset.range(5)
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
iterator = dataset3.make_initializable_iterator()
with tf.Session() as sess:
sess.run(iterator.initializer)
next1, next2 = iterator.get_next()
A = next1
B = next1 + next2
while True:
try:
a, b = sess.run([A,B])
print(a,b)
except tf.errors.OutOfRangeError:
print('done')
break
和next1
两个表达式评估A
。根据上面的引用,如果迭代器确实是针对每个评估进行的,我期望迭代器在这两个评估中都有进展,并且看到包含
B
但是,我得到的是:
(0, 2)
(2, 6)
为什么迭代器只提前一次?什么是一个显示我期望看到的行为的工作示例?
答案 0 :(得分:4)
当您在TensorFlow图中有一个改变状态的操作(如iterator.get_next()
)时,通常会出现混淆。规则很简单:
图表中的每个有状态操作(不在
tf.while_loop()
或tf.cond()
中)每个Session.run()
调用只执行一次。
应用该规则,图中只有一个iterator.get_next()
op,因此迭代器只会在Session.run()
次调用时前进一次,并且相同的元素将用于计算{{1} }和A
。
要获得所需的行为,您需要创建第二个B
操作。此外,为了获得确定性行为,我们需要确保两个iterator.get_next()
操作之间存在控制依赖关系。例如,以下程序展示了您期望的行为:
iterator.get_next()