我有如下测试代码,我该如何将索引i
递增1并生成一个张量[[i, i]]
来对其进行串联:
i0 = tf.constant(0)
m0 = tf.Variable(tf.zeros([1, 2], dtype=tf.int32))
first_set = tf.Variable(initial_value=True,dtype=tf.bool)
def body(i, m):
def cond_true_fn():
global first_set
first_set = tf.assign(first_set, False)
m = tf.assign(m0, [[i, i]])
return [i + 1, m]
def cond_false_fn():
global m0
m0 = tf.assign(m0, [[i,i]])
return [i + 1, tf.concat([m, m0], axis=0)]
return tf.cond(first_set, cond_true_fn, cond_false_fn)
def condi(i, m):
return i < 2
r = tf.while_loop(condi, body, loop_vars=[i0,m0], shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])], back_prop=False)
with tf.Session() as sess:
tf.initialize_all_variables().run()
_r = sess.run([r])
print(_r[0][0],_r[0][1])
但是结果出乎意料-> 2 [[1, 1]]
为什么_r[0][1]
不是[[0, 0], [1, 1]]
?