pytorch模型加载和预测,AttributeError:'dict'对象没有属性'predict'

时间:2019-05-06 09:47:55

标签: python-3.x machine-learning pytorch

model = torch.load('/home/ofsdms/san_mrc/checkpoint/best_v1_checkpoint.pt', map_location='cpu')
results, labels = predict_function(model, dev_data, version)

> /home/ofsdms/san_mrc/my_utils/data_utils.py(34)predict_squad()
-> phrase, spans, scores = model.predict(batch)
(Pdb) n
AttributeError: 'dict' object has no attribute 'predict'

如何加载pytorch模型的已保存检查点,并将其用于预测。我将模型保存在.pt扩展名中

1 个答案:

答案 0 :(得分:1)

您保存的检查点通常是state_dict:包含已训练权重值的字典-但不是是网络的实际体系结构。网络的实际计算图/体系结构被描述为python类(源自nn.Module)。
要使用经过训练的模型,您需要:

  1. 从实现计算图的类实例化一个model
  2. 将保存的state_dict加载到该实例:

    model.load_state_dict(torch.load('/home/ofsdms/san_mrc/checkpoint/best_v1_checkpoint.pt', map_location='cpu')