TensorArray未更新tf.cond()中的值

时间:2018-03-19 14:52:04

标签: python tensorflow machine-learning

使用Tensorflow 1.3和python 2.7

TensorArray未更新tf.cond()中的值。

import tensorflow as tf
temp1 = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
temp2 = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
temp1 = temp1.write(temp1.size(), tf.constant(1.))
temp2 = temp2.write(temp2.size(), tf.constant(10.))
cond=tf.convert_to_tensor([tf.constant(True),tf.constant(True),tf.constant(False),tf.constant(True),tf.constant(False)])

def if_body(t1, t2):
    return t1, t2

def else_body(t1, t2, i):
    m1 = tf.map_fn(lambda index: tf.cast(i, dtype=tf.float32)*tf.constant(2.), tf.range(t1.size()+tf.constant(1)), dtype=tf.float32)
    m2 = tf.map_fn(lambda index: tf.cast(i, dtype=tf.float32)*tf.constant(5.), tf.range(t2.size()+tf.constant(1)), dtype=tf.float32)
    ta1 = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
    ta2 = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
    t1 = ta1.unstack(m1)
    t2 = ta2.unstack(m2)
    # 1st return with t1 = [4.,4.], t2 = [10.,10.], i = 2.
    # end return with t1 = [8.,8.,8.], t2 = [20.,20., 20.], i = 4.
    return t1, t2 

def loop_body(i, t1, t2, cond):
    t1, t2 = tf.cond(cond[i], # for i = 2 and i = 4 condition is False 
            lambda: if_body(t1, t2),
            lambda: else_body(t1, t2, i) #1st call with else_body([1.],[10.],2), 2nd call with else_body([4.,4.],[10.,10.],4)
            )
    i = tf.add(i, tf.constant(1))
    return i, t1, t2, cond

_, temp1, temp2, _ = tf.while_loop(lambda i, temp1, temp2, cond:tf.less(i, tf.constant(5)), loop_body, [tf.constant(0), temp1, temp2, cond])

# Expected output
# temp1 = [8. 8. 8.]
# temp2 = [20. 20. 20.]

with tf.Session() as ss:
    print(ss.run([temp1.identity().stack(), temp2.identity().stack()]))

此处TensorArray temp1和temp2未更新。如果我们直接修改else_body()中的t1和t2,它会反映在temp1和temp2中。将else_body()替换为波纹管代码以对此进行舍入。

def else_body(t1, t2, i):
    t1 = t1.write(t1.size(), tf.cast(i, dtype = tf.float32))
    t2 = t2.write(t2.size(), tf.cast(i, dtype = tf.float32))    
    return t1, t2 

当我们创建一个新的TensorArray并进行撤回时,仍然从最初的TensorArray中获取值,而不是从新的TensorArray中获取值。

0 个答案:

没有答案