如何在pytorch中的修改后的vgg19网络中加载预训练权重?

时间:2019-12-14 13:11:15

标签: vgg-net pre-trained-model

我正在尝试使用修改后的输入通道数来加载vgg19网络。输入通道的数量是4,这是我的情况,而且我正在将分类器更改为自己的分类器。我还从网络中删除了自适应平均池化层。如何在PyTorch中将预先训练的权重加载到模型的修改版本中?

说我的模型的修改版本在变量myModel中。如何将vgg19的预训练权重加载到相同的值?

1 个答案:

答案 0 :(得分:0)

选项1.如果要使用原始VGG19网络提供的原始预训练权重,则必须先加载权重,然后再修改网络。 预训练的权重是为原始网络定义的,因此它需要与输入通道匹配。 然后,您可以在开头添加一个额外的层作为输入层,并删除新网络中的池化层。

选项2。您可以分别加载除输入图层之外的所有图层的权重,因为这会导致尺寸不匹配。

在代码中看起来像这样-

  # corresp_name is a dict object with mapping for your given layer 
  # name and original models layer name
  p_dict = torch.load(Path.model_dir()) #p_dict is my_model
  s_dict = self.state_dict()
  for name in p_dict:
      if name not in corresp_name:
            continue
      s_dict[corresp_name[name]] = p_dict[name]
  self.load_state_dict(s_dict)