如何理解tf.control_dependencies?

时间:2017-11-12 11:45:12

标签: tensorflow

有一个非常奇怪的例子:

w = tf.get_variable("w", shape=(), dtype=tf.int32,
                    initializer=tf.constant_initializer(2))

reset = tf.assign(w, 0)
update = tf.assign(w, w + 3)
update = tf.Print(update, [update])
reset = tf.Print(reset, [reset])
def body(i,x):
    with tf.control_dependencies([update]):
        t = tf.identity(w)

    with tf.control_dependencies([reset]):
        y = tf.identity(t)
    return i+1, y
i, z = tf.while_loop(lambda i,z: i < 20, body, [0,0])
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    print(sess.run(z))

输出为5。但是如何解释呢? 我们可以看到reset未执行,update仅执行一次 利用tf.Print。但tf.while_loop将执行body 20次。太棒了。

更新

另一个奇怪的example

import tensorflow as tf

x = tf.Variable(0, dtype=tf.int32)

old_val = tf.identity(x)
with tf.control_dependencies([old_val]):
    new_val = tf.assign(x, x + 1)


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    print(sess.run([old_val, new_val, x]))

输出为[1,1,1]。在看了github问题之后,我也很困惑。 sess.run()逐个执行,这意味着它会old_val,然后是new_val,然后是x

当它运行old_val时,它会获得0,当它运行new_val时,它会发现old_valnew_val的依赖关系,但{{1}已经运行了。因此它立即运行old_val并获得new_val,然后运行1,获取x

所以我认为它应该打印1,那有什么不对?

1 个答案:

答案 0 :(得分:2)

以下是您获得此结果的原因说明:

让我们按照您在会话中执行的节点的图表,让我们看看会发生什么。

sess.run(z)

z是tf.while_loop的第二个返回变量,因此让我们看看当我们执行tf.while_loop节点时会发生什么。

第一次迭代:

tf.while_loop执行body函数,直到谓词为真。因此,对于第一次电话:

body(i,x)调用 - &gt; body函数的主体没有引用x变量。因此

return i+1, y

while循环继续,现在将y作为x传递到正文中。

第二次迭代:

body(i, x) = run(body(_, y variable of the previous iteration))

现在Tensorflow需要解析y变量。

  1. y变量为tf.identity(t)
  2. tf.identity(t)必须在reset = tf.assign(w,0)之后执行。
  3. tf.identity(t)引用t。执行reset后,我们必须解决t,对其进行评估,然后执行y
  4. 因此:t = tf.identity(w) - &gt;仅在执行update

    之后

    因此按此顺序执行:update - &gt; t - &gt; reset - &gt; y

    评估产生:w = w + 3 -> w = 5; t = 5; w = 0; y = t = 5; return 5.

    副作用

    updatereset节点在body函数之外声明,这意味着它们只是2个独立节点,现在它们被标记为已执行 (概念上)。

    第三次迭代

    评估顺序与上一次迭代相同,但是:updatereset节点已经执行(因为执行的标志存在),因此tf.control_dependencies会跳过它们执行和Tensorflow仅执行ty

    因此:t = 5; y = 5; return 5

    正如您所看到的,从现在开始,您将始终获得5