在PyTorch中加载和冻结一个模型并训练其他模型

时间:2020-07-20 15:33:43

标签: python tensorflow pytorch torchvision

我有一个模型A,其中包括三个子模型model1,model2,model3。

模型流程:model1-> model2-> model3

我在一个独立项目中训练了model1。

问题是训练模型A时如何使用预先训练的模型1?

现在,我尝试按以下步骤实施此操作:

我通过使用model1.load_state_dict(torch.load(model1.pth))加载model1的检查点,然后将model1的参数require_grad设置为False?

对吗?

1 个答案:

答案 0 :(得分:0)

是的,这是正确的。

按照解释的方式构建模型时,您所做的是正确的。

ModelA包含三个子模型-模型1,模型,模型3

然后使用model*.load_state_dict(torch.load(model*.pth))

加载每个模型的权重

然后为您要冻结的模型制作requires_grad=False

for param in model*.parameters():
    param.requires_grad = False

您还可以通过访问子模块来冻结特定图层的权重,例如,如果在model1中有一个名为fc的图层,则可以通过制作model1.fc.weight.requres_grad = False来冻结其权重。