如何理解变压器中掩盖的多头注意力

时间:2019-09-27 02:40:48

标签: tensorflow deep-learning transformer attention-model

我目前正在研究转换器的代码,但是我无法理解解码器的屏蔽多头。论文说这是为了防止您看到生成的单词,但是如果生成单词之后的单词还没有生成,我不能理解,怎么看?

我尝试阅读转换器的代码(链接:https://github.com/Kyubyong/transformer)。代码显示掩码如下所示。它使用下三角矩阵进行遮罩,我不明白为什么。

padding_num = -2 ** 32 + 1
diag_vals = tf.ones_like(inputs[0, :, :])  # (T_q, T_k)
tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense()  # (T_q, T_k)
masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(inputs)[0], 1, 1])  # (N, T_q, T_k)
paddings = tf.ones_like(masks) * padding_num
outputs = tf.where(tf.equal(masks, 0), paddings, inputs)

1 个答案:

答案 0 :(得分:9)

在阅读Transformer paper之后,我也有同样的问题。我在互联网上找不到完整,详细的问题答案,所以我将尝试解释我对“蒙面多头注意”的理解。

简短的答案是-我们需要掩盖以使训练平行。并行化很好,因为它可以使模型训练得更快。

这是一个解释这个想法的例子。假设我们训练将“我爱你”翻译成德语。编码器以并行模式工作-它可以在恒定的步数(即步数不取决于输入序列的长度)内产生输入序列(“我爱你”)的矢量表示。

假设编码器产生数字11, 12, 13作为输入序列的向量表示。实际上,这些向量将更长,但为简单起见,我们使用了较短的向量。同样为简单起见,我们忽略了服务令牌,例如-序列的开头,-序列的结尾等。

在培训期间,我们知道翻译应为“ Ich liebe dich”(我们始终知道培训期间的预期输出)。假设“ Ich liebe dich”单词的预期矢量表示为21, 22, 23

如果我们以顺序模式进行解码器训练,则看起来就像是递归神经网络的训练。将执行以下顺序步骤:

  • 顺序操作#1。输入:11, 12, 13
    • 尝试预测21
    • 预测的输出将不完全是21,假设它将是21.1
  • 顺序操作2。输入:11, 12, 13,还有21.1作为前一个输出。
    • 尝试预测22
    • 预测的输出将不完全是22,假设它将是22.3
  • 顺序操作#3。输入11, 12, 13,还输入22.3作为前一个输出。
    • 尝试预测23
    • 预测的输出将不完全是23,假设它将是23.5

这意味着我们需要进行3次连续操作(通常情况下-每个输入都进行一次连续操作)。此外,每次下一次迭代时,我们都会累积误差。另外,我们只关注单个先前的输出,所以不会引起注意。

由于我们实际上知道预期的输出,因此我们可以调整过程并使之并行。无需等待上一步的输出。

  • 并行操作#A。输入:11, 12, 13
    • 尝试预测21
  • 并行操作#B。输入:11, 12, 13,还有21
    • 尝试预测22
  • 并行操作#C。输入:11, 12, 13,还有21, 22
    • 尝试预测23

此算法可以并行执行,并且不会累积错误。而且该算法可以吸引注意力(即查看所有先前的输入),因此可以在进行预测时获得有关上下文的更多信息。

这是我们需要遮罩的地方。训练算法知道整个预期输出(21, 22, 23)。对于每个并行操作,它都会隐藏(遮盖)此已知输出序列的一部分。

  • 执行#A时-隐藏(遮盖)整个输出。
  • 执行#B时-隐藏第二和第三输出。
  • 执行#C时-隐藏第3个输出。

掩膜本身的实现方式如下(来自original paper):

我们通过遮罩在扩大点产品关注度的范围内实现这一目标 out(设置为-∞)在softmax输入中的所有值 对应非法连接

注意:在推理(非训练)期间,解码器以顺序(非并行)模式工作,因为它最初不知道输出顺序。但这与RNN方法不同,因为Transformer推理仍使用自我关注并查看所有先前的输出(但不仅是先前的输出)。

注2:在某些材料中,我看到遮罩可用于非翻译应用程序的方式有所不同。例如,对于语言建模,可以使用遮罩从输入句子中隐藏一些单词,并且模型将在训练过程中尝试使用其他未遮掩的单词来预测它们(即学会理解上下文)。