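I am trying to build a question answering model with XLNet, using the transformers library. When I pass the input to the model, I get a "too many values to unpack" error. The same code works with the ALBERT model.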
import torch
from transformers import XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer

# filename = "/home/sundararaman/Projects/data/model_input/Tesla_report"
filename = "tesla.txt"
with open(filename, 'r') as fp:
    data = fp.readlines()
text = "India is my country. America is a continent"

def run_pred(question, text):
    # Encode the question/context pair as PyTorch tensors.
    input_dict = tokenizer.encode_plus(question, text, return_tensors='pt', max_length=512)
    input_ids = input_dict["input_ids"].tolist()
    # The ValueError is raised on the next line.
    start_scores, end_scores = model(**input_dict)
    start = torch.argmax(start_scores)
    end = torch.argmax(end_scores)
    all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    # Join the answer tokens and turn SentencePiece markers back into spaces.
    answer = ''.join(all_tokens[start: end + 1]).replace('▁', ' ').strip()
    answer = answer.replace('[SEP]', '')
    return answer if answer != '[CLS]' and len(answer) != 0 else ''

config_class, model_class, tokenizer_class = \
    XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer
model_name_or_path = "ahotrod/xlnet_large_squad2_512"
config = config_class.from_pretrained(model_name_or_path)
tokenizer = tokenizer_class.from_pretrained(model_name_or_path, do_lower_case=True)
model = model_class.from_pretrained(model_name_or_path, config=config)
The error I get is:
start_scores, end_scores = model(**input_dict)
ValueError: too many values to unpack (expected 2)
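To narrow this down, here is a minimal sketch that reproduces the same kind of error without loading any model. It assumes the XLNet model call returns more than two values, which I have not confirmed. `fake_model` is a hypothetical stand-in, and the count of five outputs is only for illustration, not taken from the library.

```python
# Hypothetical stand-in for the model call; NOT the transformers API.
# Assumption: the real call returns more than two values (five is illustrative).
def fake_model(**inputs):
    return tuple(f"output_{i}" for i in range(5))

outputs = fake_model(input_ids=[[1, 2, 3]])

# Inspect how many values come back before trying to unpack them.
num_outputs = len(outputs)

try:
    # Same two-name unpacking pattern as in run_pred.
    start_scores, end_scores = outputs
    error_message = None
except ValueError as exc:
    error_message = str(exc)

print(num_outputs)
print(error_message)
```

Running the same check on the real model, that is `outputs = model(**input_dict)` followed by `len(outputs)`, should show how many values the call actually returns.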