Tensorflow:sess.run中带有列表的变量赋值何时完成?

时间:2016-12-22 17:26:13

标签: python variables tensorflow variable-assignment timing

我认为变量赋值是在给sess.run的列表中的所有操作之后完成的,但是下面的代码在不同的执行时返回不同的结果。它似乎在列表中随机运行操作,并在列表中运行操作后分配变量。

a = tf.Variable(0)
b = tf.Variable(1)
c = tf.Variable(1)
update_a = tf.assign(a, b + c)
update_b = tf.assign(b, c + a)
update_c = tf.assign(c, a + b)

with tf.Session() as sess:
  sess.run(initialize_all_variables)
  for i in range(5):
    a_, b_, c_ = sess.run([update_a, update_b, update_c])

我想知道变量赋值的时间。 哪个是正确的:“update_x - > assign x - > ... - > udpate_z - > assign z”或“update_x - > udpate_y - > udpate_z - > assign a,b,c”? (其中(x,y,z)是(a,b,c)的排列) 另外,如果有一种方法可以实现后一种分配(在完成列表中的所有操作后完成分配),请告诉我如何实现它。

2 个答案:

答案 0 :(得分:15)

三个操作update_aupdate_bupdate_c在数据流图中没有相互依赖关系,因此TensorFlow可以选择以任何顺序执行它们。 (在当前的实现中,它们可能会在不同的线程上并行执行。)第二个缺点是默认情况下缓存了变量的读取,因此在程序中update_b中分配的值(即c + a)可以使用a的原始值或更新值,具体取决于首次读取变量的时间。

如果要确保按特定顺序执行操作,可以使用with tf.control_dependencies([...]):块强制执行在块中创建的操作在列表中指定的操作之后发生。您可以在with tf.control_dependencies([...]):块中使用tf.Variable.read_value()来显示变量的显式点。

因此,如果您想确保在update_a之前发生update_b并在update_b之前发生update_c,您可以这样做:

update_a = tf.assign(a, b + c)

with tf.control_dependencies([update_a]):
  update_b = tf.assign(b, c + a.read_value())

with tf.control_dependencies([update_b]):
  update_c = tf.assign(c, a.read_value() + b.read_value())

答案 1 :(得分:1)

基于你的这个例子,

v = tf.Variable(0)
c = tf.constant(3)
add = tf.add(v, c)
update = tf.assign(v, add)
mul = tf.mul(add, update)

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    res = sess.run([mul, mul])
    print(res)

输出:[9,9]

你得到[9, 9],这实际上就是我们要求它做的事情。可以这样想:

在运行期间,从列表中取出mul后,它会查找此定义并找到tf.mul(add, update)。现在,它需要add的值,这会导致tf.add(v, c)。因此,它插入vc的值,得到add的值为3.

好的,现在我们需要update的值,其定义为tf.assign(v, add)。我们有add的值(它现在计算为3)& v。因此,它会将v的值更新为3,这也是update的值。

现在,addupdate的值均为3.因此,乘法在mul中产生9。

根据我们得到的结果,我认为,对于列表中的下一个项目(操作),它只返回刚刚计算的mul值。我不确定它是再次执行这些步骤还是只返回它为mul计算的相同(缓存?)值,意识到我们已经结果或这些操作并行发生(对于每个列表中的元素)。也许@mrry或@YaroslavBulatov可以对这部分发表评论吗?

引用@ mrry的评论:

  

当您拨打sess.run([x, y, z])一次时,TensorFlow会执行这些张量依赖仅一次的每个操作(除非您的图表中有tf.while_loop())。如果一个张量在列表中出现两次(如示例中的mul),TensorFlow将执行一次并返回结果的两个副本。要多次运行作业,您必须多次致电sess.run(),或使用tf.while_loop()在图表中添加循环。