所以,我是tensorflow的新手,我想更新一个变量,所以我有这样的代码:
import tensorflow as tf
vector = tf.Variable(tf.zeros([10, 1], dtype=tf.int32))
def func(i):
return tf.scatter_update(vector, i, 1)
vector_updated = tf.map_fn(func, tf.constant([0, 1, 2, 3]))
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(vector_updated))
现在,当我运行此代码时,我会收到4倍的更新向量,即shape =(4,10,1)。我有点理解为什么会这样,但是我不知道如何解决。我只希望接收一次更新的向量。
请记住,我想执行此“循环”操作,因为它可以帮助我编写其他代码,因此我不希望使用所有索引来更改scatter_update。
有人可以帮忙吗?
非常感谢。