Wrong masked language modeling predictions from BertForMaskedLM

Date: 2020-04-05 16:45:26

Tags: python torch huggingface-transformers bert-language-model

When I moved to transformers == 2.7.0, I found that the LM model no longer predicts the correct masked token.

My test code is as follows:

# coding: utf-8
import torch
from transformers import BertTokenizer, BertForMaskedLM, AutoTokenizer, ElectraForMaskedLM

MODEL_PATH = 'Resources/bert-base-uncased/uncased_L-12_H-768_A-12/'

VOCAB = MODEL_PATH

print('== tokenizing ===')
tokenizer = BertTokenizer.from_pretrained(VOCAB)

# Tokenized input
text = "Who was Jim Henson ? Jim Henson was a puppeteer"
tokenized_text = tokenizer.tokenize(text)

# Mask the second 'henson' (token index 6 in the 11-token sequence)
masked_index = 6
tokenized_text[masked_index] = '[MASK]'


print('== Extracting hidden layer ===')
# Convert token to vocabulary indices
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# Hand-built segment ids and attention mask for the 11 tokens above
segments_ids = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
input_mask = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

# ======== Convert inputs to PyTorch tensors
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

# ======== predict tokens ========
print('== LM predicting ===')
# Load pre-trained model (weights)
model = BertForMaskedLM.from_pretrained(MODEL_PATH)
model.eval()

# Predict all tokens
predictions = model(tokens_tensor, segments_tensors)[0]

# confirm we were able to predict 'henson'
predicted_index = torch.argmax(predictions[0, masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
print('predicted_token', predicted_token)
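
Note that tokenized_text above carries no [CLS]/[SEP] markers, even though bert-base-uncased was pretrained on sequences wrapped in them, which by itself can degrade masked-token predictions. A minimal sketch of building the same input with special tokens, reusing text, tokenizer, and masked_index from above (encode's add_special_tokens flag and tokenizer.mask_token_id are the standard transformers 2.x tokenizer API):

# Let the tokenizer insert [CLS]/[SEP] itself; the leading [CLS] shifts
# the masked position by one.
indexed_with_special = tokenizer.encode(text, add_special_tokens=True)
indexed_with_special[masked_index + 1] = tokenizer.mask_token_id
tokens_tensor = torch.tensor([indexed_with_special])
# segments_ids and input_mask would likewise need 13 entries now.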

Details

(1) This test code worked correctly with earlier versions of pytorch_pretrained_bert, but now the model appears to predict random tokens.
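
A plausible cause for (1), assuming the call was ported unchanged: the order of the optional forward arguments differs between the two libraries, so the second positional argument in model(tokens_tensor, segments_tensors) is now interpreted as the attention mask rather than the segment ids. A sketch of the unambiguous keyword-argument call, reusing input_mask defined above:

# pytorch_pretrained_bert: forward(input_ids, token_type_ids=None, attention_mask=None)
# transformers 2.x:        forward(input_ids, attention_mask=None, token_type_ids=None)
# Passing both tensors by keyword avoids the silent swap.
attention_tensor = torch.tensor([input_mask])
predictions = model(tokens_tensor,
                    attention_mask=attention_tensor,
                    token_type_ids=segments_tensors)[0]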

(2) The random predictions also occur when I load an Electra model into ElectraForMaskedLM.
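
One possible explanation for (2), assuming the checkpoint being loaded is an ELECTRA discriminator: only the ELECTRA generator carries a masked-LM head, so loading a discriminator into ElectraForMaskedLM leaves the prediction head randomly initialized, which also looks like random tokens. A minimal sketch; the hub name google/electra-small-generator is an illustration, not taken from the code above:

# The *generator* checkpoint ships MLM weights; a discriminator checkpoint
# would leave ElectraForMaskedLM's prediction head randomly initialized.
electra_tokenizer = AutoTokenizer.from_pretrained('google/electra-small-generator')
electra_model = ElectraForMaskedLM.from_pretrained('google/electra-small-generator')
electra_model.eval()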

0 Answers