在Tensorflow中逐行更新二维变量

时间:2018-09-09 19:52:12

标签: tensorflow

是否可以使用tf.scatter_update从二维tf.variable更新一行。这个想法是,变量位于tf.while_loop内,并且在每次迭代中,所选行都会用其他内容进行更新。这个想法是:

a = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]]

我想做这样的事情

a = tf.scatter_update(a, selected_row, updated_row)

1 个答案:

答案 0 :(得分:0)

此代码从另一个张量逐行更新一个张量。请注意,我没有考虑任何性能影响。

关键是我在其中使用变量(我认为该变量会在循环内成为张量)并对其进行递增。

a = tf.scatter_update(a,v,b[i])

v = i + 1

代码是这个。

import tensorflow as tf

def cond(size, i, v):
    return tf.less(i,size)

def body(size, i, v):

    a = tf.get_variable("a",
    [3,3],dtype=tf.int32,initializer=tf.constant_initializer([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
    b = tf.get_variable("b",[3,3],dtype=tf.int32,initializer=tf.constant_initializer([[10, 11, 12], [13, 14, 15], [16, 17, 18]]))

    a = tf.scatter_update(a,v,b[i])

    v = i + 1

    tf.get_variable_scope().reuse_variables()
    with tf.control_dependencies([a]):
        return (size, i+1, v)

with tf.Session() as sess:

    i = tf.constant(0)
    v = tf.Variable(0)
    size = tf.constant(3)
    _,i,v = tf.while_loop(cond,
                    body,
                    [size, i, v])

    a = tf.get_variable("a",[3,3],dtype=tf.int32)

    init = tf.initialize_all_variables()
    sess.run(init)

    print(sess.run([v,a,i]))

输出为

  

[3,array([[10,11,12],          [13,14,15],          [16、17、18]]),3]