tensorflow' ScatterNdUpdate'操作要求输入' ref'是一个可变张量(例如:a tf.Variable)

时间:2018-05-08 12:54:10

标签: tensorflow

我尝试使用scatter_nd_update更新张量,我的代码如下:

with tf.device('/cpu:0'), tf.name_scope("embedding"):
            self.W = tf.Variable(
                tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
                name="W")
            self.embedded_chars = tf.nn.embedding_lookup(self.W, self.input_x)
            updates = tf.constant(0,shape=[embedding_size])
            for i in range(1,sequence_length - 2):
                indices = [None,i]
                tf.scatter_nd_update(self.embedded_chars,indices,updates)
            self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, -1)

然而,错误说:

TypeError:' ScatterNdUpdate'操作要求输入' ref'是一个可变张量(例如:a tf.Variable)

我知道原因是self.embedded_chars作为scatter_nd_update的参数是不可变的。

我的问题是如何定义self.embedded_chars以便将其传递给scatter_nd_update函数?

感谢任何想法。

2 个答案:

答案 0 :(得分:0)

如果您只是计算稍后将使用的值,那么只需使用tf.scatter_nd,就没有理由应用更新。我建议阅读这篇文章,讨论张量流中的the various types of tensors,我在此不再赘述。

如果您不需要将此值从一次调用sess.run维持到下一次,那么tf.scatter_nd是正确的解决方案。如果您尝试将一次调用的值保持为sess.run,那么您需要首先为此值创建一个变量,然后应用assing:

mutable_variable = tf.Variable(<initial-value>, name=<optional-name>)
with tf.control_dependencies(cost):
  tf.scatter_nd_update(mutable_variable, indices, updates)

如果您遵循第二种方法,请注意我添加了tf.control_dependencies因为赋值操作没有任何依赖关系,这是分配操作的常见问题。如果没有依赖关系,它将不会在sess.run调用的正常过程中执行,除非您明确要求执行它,添加依赖项(在这种情况下就成本函数作为示例)手动添加依赖性,因此它将在计算cost的任何时候执行。

答案 1 :(得分:0)

此问题已在另一个相关问题中得到解答,可通过 How to modify the return tensor from tf.nn.embedding_lookup()?