如何将tf.map_fn与tf.scatter_update一起正确使用?

时间:2018-07-03 12:08:50

标签: python-3.x tensorflow

所以,我是tensorflow的新手,我想更新一个变量,所以我有这样的代码:

import tensorflow as tf
vector = tf.Variable(tf.zeros([10, 1], dtype=tf.int32))


def func(i):
    return tf.scatter_update(vector, i, 1)


vector_updated = tf.map_fn(func, tf.constant([0, 1, 2, 3]))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(vector_updated))

现在,当我运行此代码时,我会收到4倍的更新向量,即shape =(4,10,1)。我有点理解为什么会这样,但是我不知道如何解决。我只希望接收一次更新的向量。

请记住,我想执行此“循环”操作,因为它可以帮助我编写其他代码,因此我不希望使用所有索引来更改scatter_update。

有人可以帮忙吗?

非常感谢。

0 个答案:

没有答案