我正在尝试使用修改后的输入通道数来加载vgg19网络。输入通道的数量是4,这是我的情况,而且我正在将分类器更改为自己的分类器。我还从网络中删除了自适应平均池化层。如何在PyTorch中将预先训练的权重加载到模型的修改版本中?
说我的模型的修改版本在变量myModel中。如何将vgg19的预训练权重加载到相同的值?
答案 0 :(得分:0)
选项1.如果要使用原始VGG19网络提供的原始预训练权重,则必须先加载权重,然后再修改网络。 预训练的权重是为原始网络定义的,因此它需要与输入通道匹配。 然后,您可以在开头添加一个额外的层作为输入层,并删除新网络中的池化层。
选项2。您可以分别加载除输入图层之外的所有图层的权重,因为这会导致尺寸不匹配。
在代码中看起来像这样-
# corresp_name is a dict object with mapping for your given layer
# name and original models layer name
p_dict = torch.load(Path.model_dir()) #p_dict is my_model
s_dict = self.state_dict()
for name in p_dict:
if name not in corresp_name:
continue
s_dict[corresp_name[name]] = p_dict[name]
self.load_state_dict(s_dict)