如何将BERT的网络输出转换为可读文本?

时间:2019-08-06 14:09:32

标签: python nlp pytorch

我试图了解如何将BERT用于QnA,并找到了有关如何在PyTorch(here)上入门的教程。现在,我想使用这些片段开始学习,但是我不明白如何将输出投影回示例文本。

text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"

(...)

# Predict the start and end positions logits
with torch.no_grad():
    start_logits, end_logits = questionAnswering_model(tokens_tensor, segments_tensors)

# Or get the total loss
start_positions, end_positions = torch.tensor([12]), torch.tensor([14])

multiple_choice_loss = questionAnswering_model(
                         tokens_tensor,
                         segments_tensors,
                         start_positions=start_positions,
                         end_positions=end_positions)
  

start_logits(形状:[1,16]):张量([[0.0196,0.1578,0.0848,0.1333,-0.4113,-0.0241,-0.1060,-0.3649,             0.0955,-0.4644,-0.1548、0.0967,-0.0659、0.1055,-0.1488,-0.3649]]]

     

end_logits(形状:[1,16]):张量([[0.1828,-0.2691,-0.0594,-0.1618,0.0441,-0.2574,-0.2883,0.2526,            -0.0551,-0.0051,-0.1572,-0.1670,-0.1219,-0.1831,-0.4463、0.2526]]]

如果我的假设正确,则需要将start_logits和end_logits投影到文本上,但是我该如何计算呢?

此外,您是否有任何资源/指南/教程可以推荐您继续进入QnA(google-research / bert github和bert的论文除外)?

谢谢。

1 个答案:

答案 0 :(得分:1)

我认为您正在尝试使用Bert进行问答,答案是原始文本的跨度。在这种情况下,原始论文将其用于SQuAD数据集。 start_logits和end_logits的令牌的登录记录是答案的开始/结束,因此您可以使用argmax,它将是令牌在文本中的索引。

此NAACL教程是关于您从您链接的仓库的作者那里学习转移学习的https://colab.research.google.com/drive/1iDHCYIrWswIKp-n-pOg69xLoZO09MEgf#scrollTo=qQ7-pH1Jp5EG,它使用分类作为目标任务,但您仍然会发现它很有用。