如何修改tf.nn.embedding_lookup()的返回张量?

时间:2018-05-09 17:30:40

标签: python tensorflow machine-learning embedding tensor

我想使用scatter_nd_update来更改从tf.nn.embedding_lookup()返回的张量的内容。但是,返回的张量不可变,scatter_nd_update()需要一个可变张量作为输入。 我花了很多时间试图找到解决方案,包括使用gen_state_ops._temporary_variable和使用tf.sparse_to_dense,但遗憾的是都失败了。

我想知道它有一个美丽的解决方案吗?

with tf.device('/cpu:0'), tf.name_scope("embedding"):
            self.W = tf.Variable(
                tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
                name="W")
            self.embedded_chars = tf.nn.embedding_lookup(self.W, self.input_x)
            updates = tf.constant(0,shape=[embedding_size])
            for i in range(1,sequence_length - 2):
                indices = [None,i]
                tf.scatter_nd_update(self.embedded_chars,indices,updates)
            self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, -1)

2 个答案:

答案 0 :(得分:1)

tf.nn.embedding_lookup只返回较大矩阵的切片,因此最简单的解决方案是更新 矩阵本身的值,在您的情况下为self.W

self.embedded_chars = tf.nn.embedding_lookup(self.W, self.input_x)

由于它是变量,因此符合tf.scatter_nd_update。请注意,您无法只更新任何张量,只能更新变量

另一种选择是为所选切片创建一个新变量,为其指定self.embedded_chars并在之后执行更新。

警告:在这两种情况下,您都会阻止渐变来训练嵌入矩阵,因此请仔细检查是否覆盖了所学习的值实际上是您想要的。

答案 1 :(得分:0)

这个问题源于不能清楚地理解张量流上下文中的张量和变量。后来随着对张量的更多了解,我想到的解决方案是:

   with tf.device('/cpu:0'), tf.name_scope("embedding"):
        self.W = tf.Variable(
            tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
            name="W")
        self.embedded_chars = tf.nn.embedding_lookup(self.W, self.input_x)
        for i in range(0,sequence_length - 1,2):
            self.tslice = tf.slice(self.embedded_chars,[0,i,0],[0,1,128])
            self.tslice2 = tf.slice(self.embedded_chars,[0,i+1,0],[0,1,128])
            self.tslice3 = tf.slice(self.embedded_chars,[0,i+2,0],[0,1,128])
            self.toffset1 = tf.subtract(self.tslice,self.tslice2)
            self.toffset2 = tf.subtract(self.tslice2,self.tslice3)
            self.tconcat = tf.concat([self.toffset1,self.toffset2],1)
        self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, -1)

使用的函数,tf.slice,tf.subtract,tf.concat都接受张量作为输入。只需避免使用需要变量作为输入的函数,如tf.scatter_nd_update。