通过pytorch中带有线性输出层的RNN发送的填充批次的掩盖和计算损失

时间:2019-12-11 19:25:31

标签: pytorch mini-batch

尽管是典型的用例,但我找不到一个简单明了的指南来说明当通过RNN发送时,在pytorch中填充的minibatch上计算损失的规范方法是什么。

我认为规范的管道可能是:

1)pytorch RNN期望填充形状为(max_seq_len,batch_size,emb_size)的批处理张量

2)因此,我们给出了一个嵌入层,例如该张量:

tensor([[1, 1],
        [2, 2],
        [3, 9]])

9是填充索引。批处理大小为2。嵌入层将使其具有形状(max_seq_len,batch_size,emb_size)。批次中的顺序是降序排列,因此我们可以打包。

3)我们应用pack_padded_sequence,我们应用RNN,最后我们应用pad_packed_sequence。此时,我们有(max_seq_len,batch_size,hidden_​​size)

4)现在,我们将线性输出层应用于结果,并假设为log_softmax。因此,最后我们得到了一系列形状分数的张量:(max_seq_len,batch_size,linear_out_size)

我应该如何从此处计算损失,以屏蔽填充部分(具有任意目标)?谢谢!

1 个答案:

答案 0 :(得分:0)

我认为PyTocrh Chatbot Tutorial可能对您有帮助。

基本上,您可以计算有效输出值的掩码(填充无效),并使用该掩码来仅计算那些值的损耗。

请参阅教程页面上的outputVarmaskNLLLoss方法。为了方便起见,我在此处复制了代码,但是您确实需要在所有代码的上下文中查看它。

# Returns padded target sequence tensor, padding mask, and max target length
def outputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    max_target_len = max([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    mask = binaryMatrix(padList)
    mask = torch.BoolTensor(mask)
    padVar = torch.LongTensor(padList)
    return padVar, mask, max_target_len
def maskNLLLoss(inp, target, mask):
    nTotal = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = crossEntropy.masked_select(mask).mean()
    loss = loss.to(device)
    return loss, nTotal.item()