我目前正在为模型定制自定义前进方法。我正在使用一些运行VGG的教程代码。我对基准模型进行了几次运行,它似乎运行良好。之后,我使用以下命令替换了VGG的forward方法:
net.forward = types.MethodType(forward_vgg_new, net)
不幸的是,教程代码保存模型的方式是:
state = {
'net':net,
'acc':acc,
'epoch':epoch,
}
...
torch.save(state, ...)
虽然这适用于原始教程代码,但在获得以下代码后,不再适用于我的自定义模型:
AttributeError:“ VGG”对象没有属性“ forward_vgg_new”
此后,我从文档中了解到,最好保存模型的state_dict:
state = {
'net':net.state_dict(),
'acc':acc,
'epoch':epoch,
}
...
torch.save(state, ...)
虽然我将为以后的运行更改代码,但我想知道是否有可能挽救我已经训练的模型。我已经天真地尝试导入VGG类并向其中添加我的forward_vgg_new方法:
setattr(VGG, 'forward_vgg_new', forward_vgg_new)
在调用torch.load之前,但是它不起作用。
答案 0 :(得分:0)
为解决此问题,我直接进入VGG库并临时添加了我的函数,以便可以加载已保存的模型并仅保存其状态字典。恢复保存后,我将更改恢复到VGG库。不是解决问题的最优雅方法,但它确实有效。