在pytorch中注意seq2seq教程的错误?

时间:2019-05-02 19:16:54

标签: pytorch seq2seq

我正在Pytorch中编写序列神经网络的序列。在official Pytorch seq2seq tutorial中,有一个我无法理解/认为可能包含错误的注意力解码器代码。

它通过将此时的输出和隐藏状态串联在一起,然后乘以一个矩阵,得到大小等于输出序列长度的向量,从而计算每个时间步的注意力权重。请注意,这些注意力权重并不取决于编码器序列(在代码中命名为encoder_outputs),

此外,the paper cited in the tutorial列出了三个可用于计算注意力权重的评分函数(本文第3.1节)。这些功能都不是仅通过矩阵进行连接和相乘。

所以在我看来,本教程中的代码在其应用的函数和传递给该函数的参数上都是错误的。我想念什么吗?

2 个答案:

答案 0 :(得分:0)

本教程在您提到的Luong论文中简化了这些注意事项。

它仅使用线性层来组合输入嵌入和解码器RNN隐藏状态。有时这称为“基于位置”的注意,因为它不依赖于编码器输出。然后应用softmax并计算注意力权重,然后按正常过程进行操作。

这并不总是很糟糕,因为从编码器输出,注意力机制可能会参与到先前的令牌,然后注意力将不会单调,因此您的模型将失败。

要实现Luong论文中的注意,我建议在对解码器隐藏状态和编码器输出都应用线性层之后,使用“ concat”注意。然后,矩阵W_a会将这些串联的结果转换为您选择的任意维度,最后v_a是一个向量,它将转换为所需的上下文向量维度。

答案 1 :(得分:0)

在算法中,attn_weights取决于解码参数。 然后我们得到线性层(这里是10)的输出。这是注意力向量。 然后我们将其乘以encoder_outputs。因此,在每个时期,我们通过反向传播来更新attn_weights。从语言上讲,每次迭代都在反向学习。 让我举个例子:

我们的任务是将英语翻译成德语。

我想唱歌。 -> Ichmöchteein Lied singen。

在解码器中,singen动词在结尾。因此,我们的解码器attn_weights可以看到解码器输出,并学习应用输入编码的哪些部分。当将此值乘以encoder_outputs时,将得到的矩阵,该矩阵的必要点具有较高的值。 因此,实际上,这是在学习解码器看到德语的句子模式时, 它必须注意输入的哪些部分。所以我认为学习的方向是正确的。