My test code is as follows:
# coding: utf-8
import torch
from transformers import BertTokenizer, BertForMaskedLM, ElectraForMaskedLM
MODEL_PATH = 'Resources/bert-base-uncased/uncased_L-12_H-768_A-12/'
VOCAB = MODEL_PATH
print('== tokenizing ===')
tokenizer = BertTokenizer.from_pretrained(VOCAB)
# Tokenized input
text = "Who was Jim Henson ? Jim Henson was a puppeteer"
tokenized_text = tokenizer.tokenize(text)
masked_index = 6  # masks the second "henson"
tokenized_text[masked_index] = '[MASK]'
print('== Extracting hidden layer ===')
# Convert token to vocabulary indices
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
input_mask = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
# ======== Convert inputs to PyTorch tensors
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
# ======== predict tokens ========
print('== LM predicting ===')
# Load pre-trained model (weights)
model = BertForMaskedLM.from_pretrained(MODEL_PATH)
model.eval()
# Predict all tokens
predictions = model(tokens_tensor, segments_tensors)[0]
# confirm we were able to predict 'henson'
predicted_index = torch.argmax(predictions[0, masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])
print('predicted_token', predicted_token)
(1) This test code worked fine with early versions of pytorch_pretrained_bert. But now the model appears to predict random tokens.
(2) The same random predictions also occur when I load an ELECTRA model into ElectraForMaskedLM.
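One likely explanation (an assumption on my part, not something the question confirms): in `transformers`, `BertForMaskedLM.forward` takes `attention_mask` as its second positional argument, whereas in `pytorch_pretrained_bert` the second positional argument was `token_type_ids`. The call `model(tokens_tensor, segments_tensors)` would therefore feed the segment ids in as an attention mask under the newer library. A dependency-free sketch of this mix-up, using hypothetical stub functions `old_forward`/`new_forward` whose argument orders mirror the two APIs:

```python
def old_forward(input_ids, token_type_ids=None, attention_mask=None):
    # pytorch_pretrained_bert ordering: (input_ids, token_type_ids, attention_mask)
    return {"segments": token_type_ids, "mask": attention_mask}

def new_forward(input_ids, attention_mask=None, token_type_ids=None):
    # transformers ordering: (input_ids, attention_mask, token_type_ids)
    return {"segments": token_type_ids, "mask": attention_mask}

tokens = [2040, 2001]          # toy input ids
segments = [0, 0, 0, 0, 0, 1]  # toy segment ids, as in the question

# Under the old API, the positional call lands segment ids where intended.
assert old_forward(tokens, segments) == {"segments": segments, "mask": None}

# Under the new API, the identical call silently routes the segment ids into
# attention_mask, so the zeros mask out the first segment's tokens and the
# model produces near-random logits for the masked position.
assert new_forward(tokens, segments) == {"segments": None, "mask": segments}

# Keyword arguments sidestep the ordering difference in either version.
assert new_forward(tokens, token_type_ids=segments) == {"segments": segments, "mask": None}
```

A separate possible contributor (also an assumption): `tokenizer.tokenize` does not add the `[CLS]` and `[SEP]` special tokens that BERT was pretrained with, which by itself can noticeably degrade masked-token predictions.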