Pytorch:参数或大小更改后是否可以加载模型?

时间:2019-06-30 13:04:13

标签: python load pytorch

Pytorch模型(图形,权重和偏差)保存为:

torch.save(self.state_dict(), file)

并加载:

self.load_state_dict(torch.load(file))

但是,如果更改了参数,则该模型不会加载错误,例如:

RuntimeError: Error(s) in loading state_dict for LeNet5:
    size mismatch for conv1.weight:

是否可以将尺寸更改的模型加载到模型? 像在初始化(如果有更多权重)中那样填充其余权重,而如果有更少权重则剪辑(clip)吗?

1 个答案:

答案 0 :(得分:1)

没有自动的方法-因为您需要明确决定在不匹配的情况下该怎么做。

个人而言,当我需要在略有变化的模型上“强制”预训练的权重时。我发现使用state_dict本身是最方便的方法。

new_model = model( ... )  # construct the new model
new_sd = new_model.state_dict()  # take the "default" state_dict
pre_trained_sd = torch.load(file)  # load the old version pre-trained weights
# merge information from pre_trained_sd into new_sd
# ...
# after merging the state dict you can load it:
new_model.load_state_dict(new_sd)