是否可以使用tf.scatter_update从二维tf.variable更新一行。这个想法是,变量位于tf.while_loop内,并且在每次迭代中,所选行都会用其他内容进行更新。这个想法是:
a = [[1, 2, 3],
[4, 5, 6],
[7, 8, 9]]
我想做这样的事情
a = tf.scatter_update(a, selected_row, updated_row)
答案 0 :(得分:0)
此代码从另一个张量逐行更新一个张量。请注意,我没有考虑任何性能影响。
关键是我在其中使用变量(我认为该变量会在循环内成为张量)并对其进行递增。
a = tf.scatter_update(a,v,b[i])
v = i + 1
代码是这个。
import tensorflow as tf
def cond(size, i, v):
return tf.less(i,size)
def body(size, i, v):
a = tf.get_variable("a",
[3,3],dtype=tf.int32,initializer=tf.constant_initializer([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
b = tf.get_variable("b",[3,3],dtype=tf.int32,initializer=tf.constant_initializer([[10, 11, 12], [13, 14, 15], [16, 17, 18]]))
a = tf.scatter_update(a,v,b[i])
v = i + 1
tf.get_variable_scope().reuse_variables()
with tf.control_dependencies([a]):
return (size, i+1, v)
with tf.Session() as sess:
i = tf.constant(0)
v = tf.Variable(0)
size = tf.constant(3)
_,i,v = tf.while_loop(cond,
body,
[size, i, v])
a = tf.get_variable("a",[3,3],dtype=tf.int32)
init = tf.initialize_all_variables()
sess.run(init)
print(sess.run([v,a,i]))
输出为
[3,array([[10,11,12], [13,14,15], [16、17、18]]),3]