在pytorch训练期间best_state随模型变化

时间:2019-06-10 12:44:56

标签: python pytorch ordereddictionary

我想保存最佳模型,然后在测试期间加载它。因此,我使用了以下方法:

def train():  
    #training steps …  
    if acc > best_acc:  
        best_state = model.state_dict()  
        best_acc = acc
    return best_state  

然后,在我使用的主要功能中:

model.load_state_dict(best_state)  

恢复模型。

但是,我发现best_state始终与训练期间的最后一个状态相同,而不是最佳状态。有谁知道原因以及如何避免它?

顺便说一句,我知道我可以使用torch.save(the_model.state_dict(), PATH)然后通过以下方式加载模型 the_model.load_state_dict(torch.load(PATH))。 但是,我不想将参数保存到文件中,因为训练和测试功能在一个文件中。

2 个答案:

答案 0 :(得分:1)

model.state_dict()OrderedDict

from collections import OrderedDict

您可以使用:

from copy import deepcopy

解决问题

相反:

best_state = model.state_dict() 

您应该使用:

best_state = copy.deepcopy(model.state_dict())

深(而不是浅)副本使可变的OrderedDict实例不会随着best_state的变化而变化。

您可以检查我的other answer,将状态字典保存在PyTorch中。

答案 1 :(得分:0)

保存模型状态时,应将以下内容保存在网络中

1)优化器状态和 2)模型的状态字典

您可以在类模型中定义以下一种方法

def save_state(state,filename):
    torch.save(state,filename)

''' 保存状态时,请执行以下操作: '''

Model model //for example  
model.save_state({'state_dict':model.state_dict(), 'optimizer': optimizer.state_dict()}) 

保存的模型将存储为model.pth.tar(例如)

现在在加载期间执行以下步骤,

checkpoint = torch.load('model.pth.tar')         

model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

希望这会对您有所帮助。