仅更新Tensorflow中单词嵌入矩阵的一部分

时间:2016-03-04 18:26:53

标签: tensorflow word-embedding

假设我想在训练期间更新预训练的字嵌入矩阵,有没有办法只更新字嵌入矩阵的子集?

我查看了Tensorflow API页面并找到了:

# Create an optimizer.
opt = GradientDescentOptimizer(learning_rate=0.1)

# Compute the gradients for a list of variables.
grads_and_vars = opt.compute_gradients(loss, <list of variables>)

# grads_and_vars is a list of tuples (gradient, variable).  Do whatever you
# need to the 'gradient' part, for example cap them, etc.
capped_grads_and_vars = [(MyCapper(gv[0]), gv[1])) for gv in grads_and_vars]

# Ask the optimizer to apply the capped gradients.
opt.apply_gradients(capped_grads_and_vars)

但是,我如何将其应用于字嵌入矩阵。假设我这样做:

word_emb = tf.Variable(0.2 * tf.random_uniform([syn0.shape[0],s['es']], minval=-1.0, maxval=1.0, dtype=tf.float32),name='word_emb',trainable=False)

gather_emb = tf.gather(word_emb,indices) #assuming that I pass some indices as placeholder through feed_dict

opt = tf.train.AdamOptimizer(1e-4)
grad = opt.compute_gradients(loss,gather_emb)

如何使用opt.apply_gradientstf.scatter_update更新原始的embeddign矩阵? (另外,如果compute_gradient的第二个参数不是tf.Variable),则tensorflow会抛出错误

3 个答案:

答案 0 :(得分:17)

TL; DR: opt.minimize(loss)的默认实现,TensorFlow将为word_emb生成稀疏更新,仅修改{的行{1}}参加了前进传球。

tf.gather(word_emb, indices) op相对于word_emb的渐变是tf.IndexedSlices个对象(see the implementation for more details)。此对象表示稀疏张量,除word_emb选择的行外,其他位置均为零。拨打indices来电AdamOptimizer._apply_sparse(word_emb_grad, word_emb),拨打opt.minimize(loss) *,仅更新由tf.scatter_sub(word_emb, ...)选择的word_emb行。

另一方面,如果您要修改opt.compute_gradients(loss, word_emb)返回的indices,则可以对其tf.IndexedSlicesindices属性执行任意TensorFlow操作,创建一个可以传递给opt.apply_gradients([(word_emb, ...)])的新values。例如,您可以使用tf.IndexedSlices(如示例中)使用以下调用来限制渐变:

MyCapper()

同样,您可以通过创建具有不同索引的新grad, = opt.compute_gradients(loss, word_emb) train_op = opt.apply_gradients( [tf.IndexedSlices(MyCapper(grad.values), grad.indices)]) 来更改将要修改的索引集。

*通常,如果您只想更新TensorFlow中变量的一部分,可以使用tf.scatter_update(), tf.scatter_add(), or tf.scatter_sub() operators,分别设置,添加到(tf.IndexedSlices)或减去({{1}先前存储在变量中的值。

答案 1 :(得分:5)

由于您只想选择要更新的元素(而不是更改渐变),您可以执行以下操作。

indices_to_update为布尔张量,表示您希望更新的索引,并在链接中定义entry_stop_gradients,然后:

gather_emb = entry_stop_gradients(gather_emb, indices_to_update)

Source

答案 2 :(得分:0)

其实我也遇到过这样的问题。就我而言,我需要使用 w2v 嵌入训练模型,但嵌入矩阵中并不存在所有标记。因此,对于那些不在矩阵中的标记,我进行了随机初始化。当然,嵌入已经训练的令牌不应该更新,因此我想出了这样的解决方案:

class PartialEmbeddingsUpdate(tf.keras.layers.Layer):
def __init__(self, len_vocab, 
             weights,
            indices_to_update):
    super(PartialEmbeddingsUpdate, self).__init__()
    self.embeddings = tf.Variable(weights, name='embedding', dtype=tf.float32)
    self.bool_mask = tf.equal(tf.expand_dims(tf.range(0,len_vocab),1), tf.expand_dims(indices_to_update,0))
    self.bool_mask = tf.reduce_any(self.bool_mask,1)
    self.bool_mask_not = tf.logical_not(self.bool_mask)
    self.bool_mask_not = tf.expand_dims(tf.cast(self.bool_mask_not, dtype=self.embeddings.dtype),1)
    self.bool_mask = tf.expand_dims(tf.cast(self.bool_mask, dtype=self.embeddings.dtype),1)
    
def call(self, input):
    input = tf.cast(input, dtype=tf.int32)
    embeddings = tf.stop_gradient(self.bool_mask_not * self.embeddings) + self.bool_mask * self.embeddings
    return tf.gather(embeddings,input)

其中 len_vocab - 是您的词汇长度,权重 - 权重矩阵(其中一些不应更新)和 indexs_to_update - 应更新的标记的索引。之后,我应用了这一层而不是 tf.keras.layers.Embeddings。希望对遇到同样问题的大家有所帮助。