我正在编写一个使用预训练的BERT
的问答系统,上面有一个线性层和一个softmax
层。遵循网上可用的模板时,一个示例的标签通常仅包含一个answer_start_index
和一个answer_end_index
。例如,从Huggingface
实例化SQUADFeatures
对象时:
```
self.input_ids = input_ids
self.attention_mask = attention_mask
self.token_type_ids = token_type_ids
self.cls_index = cls_index
self.p_mask = p_mask
self.example_index = example_index
self.unique_id = unique_id
self.paragraph_len = paragraph_len
self.token_is_max_context = token_is_max_context
self.tokens = tokens
self.token_to_orig_map = token_to_orig_map
self.start_position = start_position
self.end_position = end_position
self.is_impossible = is_impossible
self.qas_id = qas_id
```
但是,在我自己的数据集中,我有一些示例,其中在上下文中的多个位置找到了答案词,即,可能有几个正确的跨度构成了答案。
我的问题是我不知道如何管理此类示例?在网络标签上可用的模板中,通常在列表中,例如:
在我看来,这可能是这样的:
换句话说,对于每个示例,我没有一个包含一个标签的列表,但是对于一个示例,我没有一个包含单个标签或“标签”列表的列表,即由列表组成的列表。
遵循其他模板时,该过程的下一步是:
```
input_ids = torch.cat(input_ids, dim=0)
attention_masks = torch.cat(attention_masks, dim=0)
token_type_ids = torch.cat(token_type_ids, dim=0)
span_starts = torch(span_starts) #Something like this
span_ends = torch(span_ends) #Something like this
```
但是,这当然会引起一个错误,因为我的span_start列表和span_end列表不仅只包含单个项目,而且有时还包含列表中的一个列表。
有人对我如何解决这个问题有想法吗?我应该仅使用仅一个范围构成上下文中答案的示例吗?
如果我解决了火炬错误,反向传播/评估/损失计算是否仍然有效?
谢谢! / B
答案 0 :(得分:0)
您已检查代码
from transformers import BertTokenizer, BertForQuestionAnswering
import torch
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
encoding = tokenizer.encode_plus(question, text)
input_ids, token_type_ids = encoding["input_ids"], encoding["token_type_ids"]
start_scores, end_scores = model(torch.tensor([input_ids]), token_type_ids=torch.tensor([token_type_ids]))
all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
answer = ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1])
assert answer == "a nice puppet"
我不确定这是否是最好的方法,但是您可以代替argmax来使用topk
,并检查它是否与正确的答案相对应。
t = torch.LongTensor([0,1,2,3,4,5,6,7,8,9])
t
_, indices = t.topk(4)
indices#([9, 8, 7, 6])