当我尝试使用这段代码保存PyTorch模型时:
checkpoint = {'model': Net(), 'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()}
torch.save(checkpoint, 'Checkpoint.pth')
我收到以下错误:
E:\PROGRAM FILES\Anaconda\envs\staj_projesi\lib\site-packages\torch\serialization.py:251: UserWarning: Couldn't retrieve source code for container of type Net. It won't be checked for correctness upon loading.
...
"type " + obj.__name__ + ". It won't be checked "
Can't pickle local object 'trainModel.<locals>.Net'
当我尝试使用这段代码保存PyTorch模型时:
checkpoint = {'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()}
torch.save(checkpoint, 'Checkpoint.pth')
我没有收到任何错误,但是我想保存ANN类。我怎么解决这个问题?另外,我可以将模型的第一个结构保存在其他项目之前
答案 0 :(得分:1)
不能! torch.save
仅保存对象state_dict()
。
使用以下内容时:
checkpoint = {'model': Net(), 'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()}
torch.save(checkpoint, 'Checkpoint.pth')
您正在尝试保存模型本身,但是此数据保存在model.state_dict()
中,并且在使用state_dict
加载模型时,您应该首先启动模型对象。
这正是第二种方法正常工作的原因:
checkpoint = {'state_dict': model.state_dict(),'optimizer' :optimizer.state_dict()}
torch.save(checkpoint, 'Checkpoint.pth')
我建议在以下链接中阅读有关如何正确保存/加载模型的pytorch文档: https://pytorch.org/tutorials/beginner/saving_loading_models.html