我正在玩tf.cond
和tf.while_loop
。我有以下代码:
def cond_loop_test(t1, num_steps):
def cond(i, t1, t2, t3, t4):
return tf.less(i, num_steps)
def body(i, t1, t2, t3, t4):
t3, t4 = t3+1, t4+1
def last_step(t1, t2):
t1 = t1+2
t2 = t2+2
def f1(): return t1, t2
return f1
def other_steps(t1, t2):
t1 = t1+1
t2 = t2+1
def f2(): return t1, t2
return f2
tf.cond(tf.equal(i+1, num_steps), last_step(t1, t2), other_steps(t1, t2))
return [tf.add(i, 1), t1, t2, t3, t4]
return tf.while_loop(cond, body, [0,
t1,
tf.zeros_like(t1),
tf.zeros_like(5),
tf.zeros_like(5)])
t1 = tf.ones(5)
num_steps = tf.constant(5)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
cond_loop = cond_loop_test(t1, num_steps)
print(sess.run(cond_loop))
代码基本上运行了5个带有简单条件的循环,它应该在循环执行图形的过程中在不同的阶段运行不同的代码。
当我运行它时,我得到以下输出:
[5, array([1., 1., 1., 1., 1.], dtype=float32), array([0., 0., 0., 0., 0.], dtype=float32), 5, 5]
正如您所看到的,尽管在函数t1
和t2
中修改了张量last_step
和other_steps
,但它们不会被修改 - 这些函数返回无参数函数tf.cond
要求。
我期望的是这样的事情:
[5, array([8., 8.,8., 8., 8.], dtype=float32), array([7., 7., 7., 7., 7.], dtype=float32), 5, 5]
基本上,t1
和t2
都会增加4次,然后在最后一次迭代循环中增加2次。
注意:我不能在这里使用多行lambda,因为如果我做了python会对UnboundLocalError: local variable referenced before assignment
哭泣,因此丑陋的python hack只返回零arg函数返回修改后的张量(此时习惯于用这种语言进行黑客攻击)。
有人可以帮助我理解为什么在上面的代码中没有修改过t1
和t2
张量?正如您所看到的那样,张量t3
和t4
会被正确修改,唉,它们不会在条件函数中被修改。