加载和冻结预训练模型以与新网络结合

时间:2020-03-09 05:54:50

标签: python python-3.x pytorch

我有一个预先训练的模型,想在它的基础上建立一个分类器。我正在尝试加载和冻结预训练模型的权重,并将其输出传递给我希望优化的新分类器。到目前为止,这里的内容是,我有点陷入TypeError: forward() missing 1 required positional argument: 'x'行中的nn.Sequential错误:

import model #model.py contains the architecture of the pretrained model

class Classifier(nn.Module):
    def __init__(self):
        ...
    def forward(self, x):
        ...

net = model.Model()
net.load_state_dict(checkpoint["net"])

for c in net.children():
    for param in child.parameters():
        params.requires_grad = False

model = nn.Sequential(nn.ModuleList(net()), Classifier())

2 个答案:

答案 0 :(得分:0)

TL; DR

model = nn.Sequential(nn.ModuleList(net), Classifier())

您正在用net.forward“呼叫” net(),而不是__init__Classifier class Classifier()方法

答案 1 :(得分:0)

在与PyTorch论坛的@ptrblck讨论之后,我终于解决了这个问题。该解决方案与Shai的答案类似,只是因为net包含model.Model类的实例,所以应该执行model = nn.Sequential(net, Classifier())而不调用nn.ModuleList()