model = torch.load('/home/ofsdms/san_mrc/checkpoint/best_v1_checkpoint.pt', map_location='cpu')
results, labels = predict_function(model, dev_data, version)
> /home/ofsdms/san_mrc/my_utils/data_utils.py(34)predict_squad()
-> phrase, spans, scores = model.predict(batch)
(Pdb) n
AttributeError: 'dict' object has no attribute 'predict'
如何加载pytorch模型的已保存检查点,并将其用于预测。我将模型保存在.pt扩展名中
答案 0 :(得分:1)
您保存的检查点通常是state_dict
:包含已训练权重值的字典-但不是是网络的实际体系结构。网络的实际计算图/体系结构被描述为python类(源自nn.Module
)。
要使用经过训练的模型,您需要:
model
。 将保存的state_dict
加载到该实例:
model.load_state_dict(torch.load('/home/ofsdms/san_mrc/checkpoint/best_v1_checkpoint.pt', map_location='cpu')