我遇到的一些教程,使用随机初始化的嵌入矩阵进行描述,然后使用tf.nn.embedding_lookup
函数获取整数序列的嵌入。我的印象是,由于embedding_matrix
是通过tf.get_variable
获得的,优化程序会添加适当的操作来更新它。
我不明白的是,反向传播是如何通过查找函数发生的,这似乎很难而不是软。这个操作的梯度是多少?其中一个输入ID?
答案 0 :(得分:5)
嵌入矩阵查找在数学上等同于具有单热编码矩阵的点积(参见this question),这是一种平滑的线性运算。
例如,这里是索引3
的查找:
以下是渐变的公式:
...其中左侧是负对数似然的导数(即目标函数),x
是输入词,W
是嵌入矩阵,{{1是错误信号。
tf.nn.embedding_lookup
已经过优化,因此不会发生单热编码转换,但是backprop按照相同的公式工作。