I am using the fast-bert package to train a BERT model. When saving the model, fast-bert writes the following files:
-Resources
-- config.json
-- pytorch_model.bin
-- special_tokens_map.json
-- tokenizer_config.json
-- vocab.txt
-- events.out.tfevents.1601151257 (1).ffce4853f2c9
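Before loading, it may help to confirm that the model directory still contains everything listed above. Below is a minimal sketch that uses only the standard library. The helper name `missing_model_files` is hypothetical. Treating these five files as required is an assumption based on the listing, not on fast-bert's documentation.

```python
import os

# File names copied from the listing above; treating them all as required
# is an assumption, not something fast-bert documents.
EXPECTED_FILES = [
    "config.json",
    "pytorch_model.bin",
    "special_tokens_map.json",
    "tokenizer_config.json",
    "vocab.txt",
]


def missing_model_files(model_dir, expected=EXPECTED_FILES):
    """Return the expected file names that are absent from model_dir (hypothetical helper)."""
    return [name for name in expected
            if not os.path.isfile(os.path.join(model_dir, name))]
```

For example, `print(missing_model_files('/content/Resources/'))` should print an empty list if nothing is missing.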
I am trying to load the model with fast-bert and use it for prediction:
```python
from pprint import pprint

from fast_bert.prediction import BertClassificationPredictor

LABEL_DATA = '/content/data/'  # same path as during training, so labels.csv is there


def predictor(texts, MODEL_PATH='/content/Resources/'):
    """
    :param texts: texts to run prediction on
    :type texts: list
    :param MODEL_PATH: path to the trained model
    """
    if MODEL_PATH is None:
        raise LookupError("This Path is either wrong or No Trained Model exists")
    predictor = BertClassificationPredictor(
        model_path=MODEL_PATH,
        label_path=LABEL_DATA,  # location of the labels.csv file
        multi_label=False,
        model_type='xlnet',
        do_lower_case=False)
    # Batch predictions
    multiple_predictions = predictor.predict_batch(texts)
    pprint(("Predicting texts accuracy", multiple_predictions))
    return multiple_predictions
```
This fails with the following error:
OSError: Unable to load weights from pytorch checkpoint file. If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True.
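As far as I understand the transformers code that fast-bert builds on, this message is raised when the underlying `torch.load` call on `pytorch_model.bin` fails, so the original exception is hidden. A quick first check is to inspect the checkpoint file without loading it. The sketch below assumes that explanation holds and reports whether the file exists, whether it is non-empty, and which serialization format its first bytes suggest. The function name `describe_checkpoint` and its category labels are hypothetical. Mapping a zip archive to the default `torch.save` format of PyTorch 1.6+ (which older PyTorch releases are reported not to read) is a heuristic, not a guarantee.

```python
import os
import zipfile


def describe_checkpoint(path):
    """Classify a checkpoint file by existence, size and leading bytes.

    Heuristic only (hypothetical helper): it does not validate the
    tensors stored inside the file.
    """
    if not os.path.isfile(path):
        return "missing"
    if os.path.getsize(path) == 0:
        return "empty"
    if zipfile.is_zipfile(path):
        # torch.save in PyTorch >= 1.6 writes a zip archive by default
        return "zip"
    with open(path, "rb") as fh:
        head = fh.read(2)
    if head[:1] == b"\x80":
        # pickle PROTO opcode, as written by the legacy torch.save format
        return "legacy-pickle"
    # anything else, e.g. a text file or a partially downloaded page
    return "unknown"
```

For example, `print(describe_checkpoint('/content/Resources/pytorch_model.bin'))`. If it reports `zip`, a reasonable next step would be to compare `torch.__version__` in the training and prediction environments. If it reports `missing`, `empty` or `unknown`, the file itself is the more likely problem.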