Tensorflow tf.cond和多行lambda

时间:2018-01-22 03:34:21

标签: python tensorflow

我正在玩tf.condtf.while_loop。我有以下代码:

def cond_loop_test(t1, num_steps):
    def cond(i, t1, t2, t3, t4):
        return tf.less(i, num_steps)

    def body(i, t1, t2, t3, t4):
        t3, t4 = t3+1, t4+1

        def last_step(t1, t2):
            t1 = t1+2
            t2 = t2+2
            def f1(): return t1, t2
            return f1

        def other_steps(t1, t2):
            t1 = t1+1
            t2 = t2+1
            def f2(): return t1, t2
            return f2

        tf.cond(tf.equal(i+1, num_steps), last_step(t1, t2), other_steps(t1, t2))

        return [tf.add(i, 1), t1, t2, t3, t4]

    return tf.while_loop(cond, body, [0,
                                      t1,
                                      tf.zeros_like(t1),
                                      tf.zeros_like(5),
                                      tf.zeros_like(5)])

t1 = tf.ones(5)
num_steps = tf.constant(5)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    cond_loop = cond_loop_test(t1, num_steps)
    print(sess.run(cond_loop))

代码基本上运行了5个带有简单条件的循环,它应该在循环执行图形的过程中在不同的阶段运行不同的代码。

当我运行它时,我得到以下输出:

[5, array([1., 1., 1., 1., 1.], dtype=float32), array([0., 0., 0., 0., 0.], dtype=float32), 5, 5]

正如您所看到的,尽管在函数t1t2中修改了张量last_stepother_steps,但它们不会被修改 - 这些函数返回无参数函数tf.cond要求。

我期望的是这样的事情:

[5, array([8., 8.,8., 8., 8.], dtype=float32), array([7., 7., 7., 7., 7.], dtype=float32), 5, 5]

基本上,t1t2都会增加4次,然后在最后一次迭代循环中增加2次。

注意:我不能在这里使用多行lambda,因为如果我做了python会对UnboundLocalError: local variable referenced before assignment哭泣,因此丑陋的python hack只返回零arg函数返回修改后的张量(此时习惯于用这种语言进行黑客攻击)。

有人可以帮助我理解为什么在上面的代码中没有修改过t1t2张量?正如您所看到的那样,张量t3t4会被正确修改,唉,它们不会在条件函数中被修改。

0 个答案:

没有答案