Pytorch:使用自定义的VGG模型恢复网络,该模型保存不当

时间:2020-06-02 18:59:51

标签: pytorch

我目前正在为模型定制自定义前进方法。我正在使用一些运行VGG的教程代码。我对基准模型进行了几次运行,它似乎运行良好。之后,我使用以下命令替换了VGG的forward方法:

net.forward = types.MethodType(forward_vgg_new, net)

不幸的是,教程代码保存模型的方式是:

            state = {
                    'net':net,
                    'acc':acc,
                    'epoch':epoch,
            }
...
            torch.save(state, ...)

虽然这适用于原始教程代码,但在获得以下代码后,不再适用于我的自定义模型:

AttributeError:“ VGG”对象没有属性“ forward_vgg_new”

此后,我从文档中了解到,最好保存模型的state_dict:

            state = {
                    'net':net.state_dict(),
                    'acc':acc,
                    'epoch':epoch,
            }
...
            torch.save(state, ...)

虽然我将为以后的运行更改代码,但我想知道是否有可能挽救我已经训练的模型。我已经天真地尝试导入VGG类并向其中添加我的forward_vgg_new方法:

setattr(VGG, 'forward_vgg_new', forward_vgg_new)

在调用torch.load之前,但是它不起作用。

1 个答案:

答案 0 :(得分:0)

为解决此问题,我直接进入VGG库并临时添加了我的函数,以便可以加载已保存的模型并仅保存其状态字典。恢复保存后,我将更改恢复到VGG库。不是解决问题的最优雅方法,但它确实有效。