使用Tensorflow 1.3和python 2.7
TensorArray未更新tf.cond()中的值。
import tensorflow as tf
temp1 = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
temp2 = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
temp1 = temp1.write(temp1.size(), tf.constant(1.))
temp2 = temp2.write(temp2.size(), tf.constant(10.))
cond=tf.convert_to_tensor([tf.constant(True),tf.constant(True),tf.constant(False),tf.constant(True),tf.constant(False)])
def if_body(t1, t2):
return t1, t2
def else_body(t1, t2, i):
m1 = tf.map_fn(lambda index: tf.cast(i, dtype=tf.float32)*tf.constant(2.), tf.range(t1.size()+tf.constant(1)), dtype=tf.float32)
m2 = tf.map_fn(lambda index: tf.cast(i, dtype=tf.float32)*tf.constant(5.), tf.range(t2.size()+tf.constant(1)), dtype=tf.float32)
ta1 = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
ta2 = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
t1 = ta1.unstack(m1)
t2 = ta2.unstack(m2)
# 1st return with t1 = [4.,4.], t2 = [10.,10.], i = 2.
# end return with t1 = [8.,8.,8.], t2 = [20.,20., 20.], i = 4.
return t1, t2
def loop_body(i, t1, t2, cond):
t1, t2 = tf.cond(cond[i], # for i = 2 and i = 4 condition is False
lambda: if_body(t1, t2),
lambda: else_body(t1, t2, i) #1st call with else_body([1.],[10.],2), 2nd call with else_body([4.,4.],[10.,10.],4)
)
i = tf.add(i, tf.constant(1))
return i, t1, t2, cond
_, temp1, temp2, _ = tf.while_loop(lambda i, temp1, temp2, cond:tf.less(i, tf.constant(5)), loop_body, [tf.constant(0), temp1, temp2, cond])
# Expected output
# temp1 = [8. 8. 8.]
# temp2 = [20. 20. 20.]
with tf.Session() as ss:
print(ss.run([temp1.identity().stack(), temp2.identity().stack()]))
此处TensorArray temp1和temp2未更新。如果我们直接修改else_body()中的t1和t2,它会反映在temp1和temp2中。将else_body()替换为波纹管代码以对此进行舍入。
def else_body(t1, t2, i):
t1 = t1.write(t1.size(), tf.cast(i, dtype = tf.float32))
t2 = t2.write(t2.size(), tf.cast(i, dtype = tf.float32))
return t1, t2
当我们创建一个新的TensorArray并进行撤回时,仍然从最初的TensorArray中获取值,而不是从新的TensorArray中获取值。