我想保存最佳模型,然后在测试期间加载它。因此,我使用了以下方法:
def train():
#training steps …
if acc > best_acc:
best_state = model.state_dict()
best_acc = acc
return best_state
然后,在我使用的主要功能中:
model.load_state_dict(best_state)
恢复模型。
但是,我发现best_state始终与训练期间的最后一个状态相同,而不是最佳状态。有谁知道原因以及如何避免它?
顺便说一句,我知道我可以使用torch.save(the_model.state_dict(), PATH)
然后通过以下方式加载模型
the_model.load_state_dict(torch.load(PATH))
。
但是,我不想将参数保存到文件中,因为训练和测试功能在一个文件中。
答案 0 :(得分:1)
model.state_dict()
是OrderedDict
from collections import OrderedDict
您可以使用:
from copy import deepcopy
解决问题
相反:
best_state = model.state_dict()
您应该使用:
best_state = copy.deepcopy(model.state_dict())
深(而不是浅)副本使可变的OrderedDict实例不会随着best_state
的变化而变化。
您可以检查我的other answer,将状态字典保存在PyTorch中。
答案 1 :(得分:0)
保存模型状态时,应将以下内容保存在网络中
1)优化器状态和 2)模型的状态字典
您可以在类模型中定义以下一种方法
def save_state(state,filename):
torch.save(state,filename)
''' 保存状态时,请执行以下操作: '''
Model model //for example
model.save_state({'state_dict':model.state_dict(), 'optimizer': optimizer.state_dict()})
保存的模型将存储为model.pth.tar(例如)
现在在加载期间执行以下步骤,
checkpoint = torch.load('model.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
希望这会对您有所帮助。