如何正确使用tf.while_loop?

时间:2018-07-29 17:47:53

标签: tensorflow while-loop

我有如下测试代码,我该如何将索引i递增1并生成一个张量[[i, i]]来对其进行串联:

i0 = tf.constant(0)
m0 = tf.Variable(tf.zeros([1, 2], dtype=tf.int32))
first_set = tf.Variable(initial_value=True,dtype=tf.bool)

def body(i, m):
    def cond_true_fn():
        global first_set
        first_set = tf.assign(first_set, False)
        m = tf.assign(m0, [[i, i]])
        return [i + 1, m]
    def cond_false_fn():
        global m0
        m0 = tf.assign(m0, [[i,i]])
        return [i + 1, tf.concat([m, m0], axis=0)]

    return tf.cond(first_set, cond_true_fn, cond_false_fn)

def condi(i, m):
    return i < 2

r = tf.while_loop(condi, body, loop_vars=[i0,m0], shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])], back_prop=False)

with tf.Session() as sess:
    tf.initialize_all_variables().run()
    _r = sess.run([r])
    print(_r[0][0],_r[0][1])

但是结果出乎意料-> 2 [[1, 1]]
为什么_r[0][1]不是[[0, 0], [1, 1]]

0 个答案:

没有答案