保存并加载检查点pytorch

时间:2018-11-29 11:35:43

标签: python-3.x pytorch rnn checkpointing

我制作一个模型并将配置保存为:

def checkpoint(state, ep, filename='./Risultati/checkpoint.pth'):  
    if ep == (n_epoch-1):
        print('Saving state...')
        torch.save(state,filename)
checkpoint({'state_dict':rnn.state_dict()},ep) 

然后我要加载此配置:

state_dict= torch.load('./Risultati/checkpoint.pth')
    rnn.state_dict(state_dict)

当我尝试时,这是错误消息:

Traceback (most recent call last):
File "train.py", line 288, in <module>
rnn.state_dict(state_dict)
File "/home/marco/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 593, in state_dict
destination._metadata[prefix[:-1]] = dict(version=self._version)
AttributeError: 'dict' object has no attribute '_metadata'

我做错了什么地方

提前

1 个答案:

答案 0 :(得分:0)

您需要加载存储在您加载的字典中的rnn.state_dict()

rnn.load_state_dict(state_dict['state_dict'])

请查看load_state_dict方法以获取更多信息。