为什么我们要做批量矩阵矩阵产品?

时间:2018-06-12 22:23:42

标签: deep-learning pytorch seq2seq

我正在关注Pytorch seq2seq tutorial并使用torch.bmm方法,如下所示:

attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                         encoder_outputs.unsqueeze(0))

我理解为什么我们需要增加注意力和编码器输出。

我不太明白的是我们在这里需要bmm方法的原因。 torch.bmm文件说

  

执行以batch1和batch2存储的矩阵的批处理矩阵 - 矩阵乘积。

     

batch1和batch2必须是3-D张量,每个张量包含相同数量的矩阵。

     

如果batch1是(b×n×m)张量,batch2是a(b×m×p)张量,out将是a(b×n×p)张量。

enter image description here

3 个答案:

答案 0 :(得分:4)

在seq2seq模型中,编码器以迷你批次的形式对输入序列进行编码。例如,输入为B x S x d,其中B是批量大小,S是最大序列长度,d是单词嵌入维度。然后编码器的输出为B x S x h,其中h是编码器的隐藏状态大小(即RNN)。

现在解码时(训练期间) 输入序列一次一个,因此输入为B x 1 x d,解码器产生一个形状张量B x 1 x h。现在计算上下文向量,我们需要将此解码器隐藏状态与编码器的编码状态进行比较。

因此,请考虑您有两个形状为T1 = B x S x hT2 = B x 1 x h的张量。因此,如果您可以按如下方式进行批量矩阵乘法。

out = torch.bmm(T1, T2.transpose(1, 2))

基本上,您将形状B x S x h的张量与形状张量B x h x 1相乘,这将导致B x S x 1这是每批次的注意力。

这里,注意力权重B x S x 1表示解码器的当前隐藏状态和编码器的所有隐藏状态之间的相似性得分。现在,您可以通过首先调换来使注意权重乘以编码器的隐藏状态B x S x h,这将导致形状B x h x 1的张量。如果你在dim = 2处执行squeeze,你将得到一个形状张量B x h,这是你的上下文向量。

此上下文向量(B x h)通常连接到解码器的隐藏状态(B x 1 x h,挤压dim = 1)以预测下一个令牌。

答案 1 :(得分:2)

上图中描述的操作发生在Seq2Seq模型的Decoder侧。意味着编码器输出已经按批次(小批量大小样本)表示。因此,attn_weights张量也应处于批处理模式。

因此,从本质上讲,张量zeroattn_weights的第一维(NumPy术语中的encoder_outputs轴)是小批量大小的样本数即可。因此,我们需要torch.bmm这两个张量。

答案 2 :(得分:2)

@wasiahmad关于seq2seq的一般实现是正确的,在提到的教程中没有批处理(B = 1),并且bmm只是工程过度,可以安全地替换为{{1} },并具有完全相同的模型质量和性能。自己看看,替换成这个:

matmul

与此:

        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))
        output = torch.cat((embedded[0], attn_applied[0]), 1)

然后运行笔记本电脑。


此外,请注意,虽然@wasiahmad将编码器输入称为 attn_applied = torch.matmul(attn_weights, encoder_outputs) output = torch.cat((embedded[0], attn_applied), 1) ,但在pytorch 1.7.0中,作为编码器主要引擎的GRU希望输入格式为{{1 }} 默认。如果要使用@wasiahmad格式,请传递B x S x d标志。