在上下文中有多个答案范围,BERT问题解答

时间:2020-05-25 14:08:08

标签: python pytorch bert-language-model question-answering squad

我正在编写一个使用预训练的BERT的问答系统,上面有一个线性层和一个softmax层。遵循网上可用的模板时,一个示例的标签通常仅包含一个answer_start_index和一个answer_end_index。例如,从Huggingface实例化SQUADFeatures对象时:

```
self.input_ids = input_ids
self.attention_mask = attention_mask
self.token_type_ids = token_type_ids
self.cls_index = cls_index
self.p_mask = p_mask

self.example_index = example_index
self.unique_id = unique_id
self.paragraph_len = paragraph_len
self.token_is_max_context = token_is_max_context
self.tokens = tokens
self.token_to_orig_map = token_to_orig_map

self.start_position = start_position
self.end_position = end_position
self.is_impossible = is_impossible
self.qas_id = qas_id
```

但是,在我自己的数据集中,我有一些示例,其中在上下文中的多个位置找到了答案词,即,可能有几个正确的跨度构成了答案。

我的问题是我不知道如何管理此类示例?在网络标签上可用的模板中,通常在列表中,例如:

  • [start_example1,start_example2,start_example3]
  • [end_example1,end_example2,end_example3]

在我看来,这可能是这样的:

  • [start_example1,[start_example2_1,start_example2_2],start_example3]
  • 当然也要结束

换句话说,对于每个示例,我没有一个包含一个标签的列表,但是对于一个示例,我没有一个包含单个标签或“标签”列表的列表,即由列表组成的列表。

遵循其他模板时,该过程的下一步是:

```
input_ids = torch.cat(input_ids, dim=0)
attention_masks = torch.cat(attention_masks, dim=0)
token_type_ids = torch.cat(token_type_ids, dim=0)
span_starts = torch(span_starts) #Something like this
span_ends = torch(span_ends) #Something like this
```

但是,这当然会引起一个错误,因为我的span_start列表和span_end列表不仅只包含单个项目,而且有时还包含列表中的一个列表。

有人对我如何解决这个问题有想法吗?我应该仅使用仅一个范围构成上下文中答案的示例吗?

如果我解决了火炬错误,反向传播/评估/损失计算是否仍然有效?

谢谢! / B

1 个答案:

答案 0 :(得分:0)

您已检查代码

from transformers import BertTokenizer, BertForQuestionAnswering
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
encoding = tokenizer.encode_plus(question, text)
input_ids, token_type_ids = encoding["input_ids"], encoding["token_type_ids"]
start_scores, end_scores = model(torch.tensor([input_ids]), token_type_ids=torch.tensor([token_type_ids]))

all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
answer = ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1])

assert answer == "a nice puppet"

我不确定这是否是最好的方法,但是您可以代替argmax来使用topk,并检查它是否与正确的答案相对应。

t = torch.LongTensor([0,1,2,3,4,5,6,7,8,9])
t
_, indices = t.topk(4)
indices#([9, 8, 7, 6])