如何在while循环内更改Tensor的单个值?
我知道我可以使用tf.Variable
来操纵tf.scatter_update(variable, index, value)
的单个值,但是在循环内部,我无法访问变量。有没有办法在while循环内操作Tensor
的给定值。
作为参考,这是我当前的代码:
my_variable = tf.Variable()
def body(i, my_variable):
[...]
return tf.add(i, 1), tf.scatter_update(my_variable, [index], value)
loop = tf.while_loop(lambda i, _: tf.less(i, 5), body, [0, my_variable])
答案 0 :(得分:2)
受this post的启发,您可以使用稀疏张量将增量存储到要分配的值,然后使用加法来“设置”该值。例如。像这样(我在这里假设一些形状/值,但是将其概括为更高等级的张量应该很简单):
import tensorflow as tf
my_variable = tf.Variable(tf.ones([5]))
def body(i, v):
index = i
new_value = 3.0
delta_value = new_value - v[index:index+1]
delta = tf.SparseTensor([[index]], delta_value, (5,))
v_updated = v + tf.sparse_tensor_to_dense(delta)
return tf.add(i, 1), v_updated
_, updated = tf.while_loop(lambda i, _: tf.less(i, 5), body, [0, my_variable])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(my_variable))
print(sess.run(updated))
此打印
[1. 1. 1. 1. 1.]
[3. 3. 3. 3. 3.]