PyTorch模型保存错误:“无法腌制本地对象”

时间:2020-05-29 16:36:55

标签: python object save pytorch pickle

当我尝试使用这段代码保存PyTorch模型时:

checkpoint = {'model': Net(), 'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()}
torch.save(checkpoint, 'Checkpoint.pth')

我收到以下错误:

    E:\PROGRAM FILES\Anaconda\envs\staj_projesi\lib\site-packages\torch\serialization.py:251: UserWarning: Couldn't retrieve source code for container of type Net. It won't be checked for correctness upon loading.
...

      "type " + obj.__name__ + ". It won't be checked "
    Can't pickle local object 'trainModel.<locals>.Net'

当我尝试使用这段代码保存PyTorch模型时:

checkpoint = {'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()}
torch.save(checkpoint, 'Checkpoint.pth')

我没有收到任何错误,但是我想保存ANN类。我怎么解决这个问题?另外,我可以将模型的第一个结构保存在其他项目之前

1 个答案:

答案 0 :(得分:1)

不能! torch.save仅保存对象state_dict()

使用以下内容时:

checkpoint = {'model': Net(), 'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()}
torch.save(checkpoint, 'Checkpoint.pth')

您正在尝试保存模型本身,但是此数据保存在model.state_dict()中,并且在使用state_dict加载模型时,您应该首先启动模型对象。

这正是第二种方法正常工作的原因:

checkpoint = {'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()}
torch.save(checkpoint, 'Checkpoint.pth')

我建议在以下链接中阅读有关如何正确保存/加载模型的pytorch文档: https://pytorch.org/tutorials/beginner/saving_loading_models.html