TransformerEncoder with padding mask

Time: 2020-06-16 00:43:00

Tags: pytorch transformer attention-model

I am trying to use torch.nn.TransformerEncoder with a src_key_padding_mask that is not None. The input src has shape [20, 95], and the binary padding mask src_mask also has shape [20, 95], with 1 at the positions of padding tokens and 0 everywhere else. I build a transformer encoder with 6 layers, each containing self-attention with 8 heads and a feed-forward hidden size of 256:

import torch

layer = torch.nn.TransformerEncoderLayer(256, 8, 256, 0.1)
encoder = torch.nn.TransformerEncoder(layer, 6)
embed = torch.nn.Embedding(80000, 256)
src = torch.randint(0, 1000, (20, 95))
src = embed(src)
src_mask = torch.randint(0, 2, (20, 95))
output = encoder(src, src_mask)

But I get the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-107-31bf7ab8384b> in <module>
----> 1 output =  encoder(src, src_mask)

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    545             result = self._slow_forward(*input, **kwargs)
    546         else:
--> 547             result = self.forward(*input, **kwargs)
    548         for hook in self._forward_hooks.values():
    549             hook_result = hook(self, input, result)

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/transformer.py in forward(self, src, mask, src_key_padding_mask)
    165         for i in range(self.num_layers):
    166             output = self.layers[i](output, src_mask=mask,
--> 167                                     src_key_padding_mask=src_key_padding_mask)
    168 
    169         if self.norm:

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    545             result = self._slow_forward(*input, **kwargs)
    546         else:
--> 547             result = self.forward(*input, **kwargs)
    548         for hook in self._forward_hooks.values():
    549             hook_result = hook(self, input, result)

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/transformer.py in forward(self, src, src_mask, src_key_padding_mask)
    264         """
    265         src2 = self.self_attn(src, src, src, attn_mask=src_mask,
--> 266                               key_padding_mask=src_key_padding_mask)[0]
    267         src = src + self.dropout1(src2)
    268         src = self.norm1(src)

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    545             result = self._slow_forward(*input, **kwargs)
    546         else:
--> 547             result = self.forward(*input, **kwargs)
    548         for hook in self._forward_hooks.values():
    549             hook_result = hook(self, input, result)

~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/activation.py in forward(self, query, key, value, key_padding_mask, need_weights, attn_mask)
    781                 training=self.training,
    782                 key_padding_mask=key_padding_mask, need_weights=need_weights,
--> 783                 attn_mask=attn_mask)
    784 
    785 

~/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py in multi_head_attention_forward(query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training, key_padding_mask, need_weights, attn_mask, use_separate_proj_weight, q_proj_weight, k_proj_weight, v_proj_weight, static_k, static_v)
   3250     if attn_mask is not None:
   3251         attn_mask = attn_mask.unsqueeze(0)
-> 3252         attn_output_weights += attn_mask
   3253 
   3254     if key_padding_mask is not None:

RuntimeError: The size of tensor a (20) must match the size of tensor b (95) at non-singleton dimension 2

I was wondering if anyone could help me figure this out.

Thanks

1 Answer:

Answer 0 (score: 5)

The required shapes are shown in nn.Transformer.forward - Shape (all the building blocks of the transformer refer to it). The ones relevant for the encoder are:

  • src: (S, N, E)
  • src_mask: (S, S)
  • src_key_padding_mask: (N, S)

where S is the sequence length, N the batch size, and E the embedding dimension (number of features).
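
As a quick illustration, here is a toy sketch (made-up sizes, not the ones from the question) of how the three encoder arguments line up:

import torch

S, N, E = 5, 3, 16                                          # sequence length, batch size, embedding dim
src = torch.rand(S, N, E)                                   # encoder input, shape (S, N, E)
src_mask = torch.zeros(S, S, dtype=torch.bool)              # optional attention mask, shape (S, S)
src_key_padding_mask = torch.zeros(N, S, dtype=torch.bool)  # padding mask, shape (N, S)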

That means your padding mask needs to have shape [95, 20] instead of [20, 95]. This assumes that your batch size is 95 with a sequence length of 20; if that is not the case, you have to transpose src instead.
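
For example, a sketch of building a correctly oriented [95, 20] boolean padding mask from per-sequence lengths (the lengths tensor here is hypothetical, just for illustration):

import torch

N, S = 95, 20                                              # batch size, sequence length
lengths = torch.randint(1, S + 1, (N,))                    # hypothetical real length of each sequence
positions = torch.arange(S).unsqueeze(0)                   # shape (1, S)
src_key_padding_mask = positions >= lengths.unsqueeze(1)   # shape (N, S), True at padded positions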

Additionally, when calling the encoder you are not actually specifying src_key_padding_mask but src_mask, because the signature of torch.nn.TransformerEncoder.forward is:

forward(src, mask=None, src_key_padding_mask=None)

The padding mask must be passed as the keyword argument src_key_padding_mask rather than as the second positional argument. To avoid confusion, your src_mask should also be renamed to src_key_padding_mask:

src_key_padding_mask = torch.randint(0, 2, (95, 20)).bool()  # True (1) marks padding positions
output = encoder(src, src_key_padding_mask=src_key_padding_mask)
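
Putting it all together, a corrected version of the snippet from the question could look like this sketch (it keeps the question's sizes: batch size 95, sequence length 20, embedding dimension 256, vocabulary of 80000):

import torch

layer = torch.nn.TransformerEncoderLayer(256, 8, 256, 0.1)
encoder = torch.nn.TransformerEncoder(layer, 6)
embed = torch.nn.Embedding(80000, 256)

src = torch.randint(0, 1000, (20, 95))                          # (S, N) token ids with S=20, N=95
src = embed(src)                                                # (S, N, E) with E=256
src_key_padding_mask = torch.randint(0, 2, (95, 20)).bool()     # (N, S), True marks padding
output = encoder(src, src_key_padding_mask=src_key_padding_mask)  # (S, N, E)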