Error when loading state_dict for a custom model

Time: 2019-04-24 08:48:30

Tags: pytorch

I'm having trouble loading the model weights. Here is part of the model:

import torch.nn as nn

# BasicConv2d, Mixed_3a, Mixed_4a, Mixed_5a and Inception_A are defined elsewhere (not shown).
class InceptionV4(nn.Module):

    def __init__(self, num_classes=1001):
        super(InceptionV4, self).__init__()
        # Special attributes
        self.input_space = None
        self.input_size = (299, 299, 3)
        self.mean = None
        self.std = None
        # Modules
        self.features = nn.Sequential(
            BasicConv2d(3, 32, kernel_size=3, stride=2),
            BasicConv2d(32, 32, kernel_size=3, stride=1),
            BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1),
            Mixed_3a(),
            Mixed_4a(),
            Mixed_5a(),
            Inception_A(),
            Inception_A(),
            Inception_A(),
            ...
        )
        self.avg_pool = nn.AvgPool2d(8, count_include_pad=False)
        self.last_linear = nn.Linear(1536, num_classes)

I tried saving the weights with torch.save(model.state_dict(), weight_name) and then reloading them with model.load_state_dict(torch.load(weight_name)), but I get the following error:

Missing key(s) in state_dict: "features.0.conv.weight", "features.0.bn.weight", "features.0.bn.bias", "features.0.bn.running_mean", "features.0.bn.running_var", "features.1.conv.weight", "features.1.bn.weight", "features.1.bn.bias", "features.1.bn.running_mean", "features.1.bn.running_var", "features.2.conv.weight", "features.2.bn.weight

and also:

Unexpected key(s) in state_dict: "conv.0.conv1.0.weight", "conv.0.conv1.0.bias", "conv.0.conv1.2.weight", "conv.0.conv1.2.bias", "conv.0.conv1.2.running_mean", "conv.0.conv1.2.running_var", "conv.0.conv1.2.num_batches_tracked", "conv.0.conv2.0.weight", "conv.0.conv2.0.bias", "conv.0.conv2.2.weight", "conv.0.conv2.2.bias", "conv.0.conv2.2.running_mean", "conv.0.conv2.2.running_var", "conv.0.conv2.2.num_batches_tracked", "conv.1.conv1.0.weight", "conv.1.conv1.0.bias", "conv.1.conv1.2.weight", "conv.1.conv1.2.bias", "conv.1.conv1.2.running_mean", "conv.1.conv1.2.running_var", "conv.1.conv1.2.num_batches_tracked

Any hints on this? Thanks in advance.

1 Answer:

Answer 0 (score: 1)

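I have run into this problem several times. The error means that the key names in your model's state_dict differ from the key names in the pre-trained weights you are loading.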
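To see exactly which names disagree, you can compare the two key sets yourself. A minimal sketch: `diff_state_dict_keys` and `Model` are hypothetical names, and the "file" is a placeholder dict whose values are irrelevant because only the keys are compared. With a real checkpoint you would pass `torch.load(weight_name)` instead.

```python
import torch.nn as nn


def diff_state_dict_keys(model, state_dict):
    """Return (missing, unexpected) key lists, like load_state_dict's error message."""
    model_keys = set(model.state_dict().keys())
    file_keys = set(state_dict.keys())
    missing = sorted(model_keys - file_keys)     # the model expects these, the file lacks them
    unexpected = sorted(file_keys - model_keys)  # the file has these, the model does not
    return missing, unexpected


class Model(nn.Module):
    """Hypothetical model whose block is called "features"."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3))


# Placeholder "file" whose block was called "conv"; values do not matter here.
saved = {"conv.0.weight": None, "conv.0.bias": None}
missing, unexpected = diff_state_dict_keys(Model(), saved)
print(missing)     # ['features.0.bias', 'features.0.weight']
print(unexpected)  # ['conv.0.bias', 'conv.0.weight']
```

Looking at the common prefix of the unexpected keys usually tells you how the original code named and nested its modules.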

I don't see a pre-trained Inception_v4 in the torchvision model zoo, so it is hard to say exactly where your InceptionV4 class diverges from the saved state dict.

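Wherever you got the pre-trained file from, the key is to define your model exactly as the code that produced the pre-trained weights does; then the weight file will load without problems.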
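As an illustration, the save/load round trip succeeds when the loading model is built from the same class as the saving one, because every key matches. A minimal sketch: `TinyNet` is a hypothetical stand-in for InceptionV4 (whose real blocks are not shown), and an in-memory `io.BytesIO` buffer stands in for `weight_name`.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical small model standing in for InceptionV4."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, bias=False),
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True),
        )
        self.last_linear = nn.Linear(8, num_classes)


model = TinyNet()
buffer = io.BytesIO()  # in-memory file instead of weight_name
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# A fresh instance of the *same* class has identical key names, so this loads cleanly.
restored = TinyNet()
restored.load_state_dict(torch.load(buffer))
print(sorted(restored.state_dict()) == sorted(model.state_dict()))  # True
```

If the file had instead been written by code whose top-level attribute is `conv` rather than `features`, the same `load_state_dict` call would raise the "Missing key(s)" / "Unexpected key(s)" error from the question.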

Here are some signs that your code differs from the model the weights came from:


# Rename self.features -> self.conv: this fixes the mismatched top-level names.

self.conv = nn.Sequential(...)
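The reason the rename helps is that the attribute name becomes the first component of every state_dict key, followed by the index inside the `nn.Sequential`. A minimal sketch with a hypothetical `Renamed` module holding a single convolution:

```python
import torch.nn as nn


class Renamed(nn.Module):
    """Hypothetical module: only the attribute name differs from InceptionV4's `features`."""

    def __init__(self):
        super().__init__()
        # "conv" becomes the key prefix; the Sequential index "0" comes next.
        self.conv = nn.Sequential(nn.Conv2d(3, 32, kernel_size=3, stride=2))


print(list(Renamed().state_dict().keys()))  # ['conv.0.weight', 'conv.0.bias']
```

Note that renaming fixes only the first component: keys such as `conv.0.conv1.0.weight` also expect a different submodule layout underneath, so the blocks themselves must match too.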


# Look up how BatchNorm differs between your current PyTorch version
# and the older PyTorch version in which the pre-trained model was defined.

conv.1.conv1.2.num_batches_tracked  # this BatchNorm buffer was added in PyTorch 0.4.x (it is not deprecated), so these weights were saved with 0.4 or newer
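If, after the names are fixed, the only leftover unexpected keys are `num_batches_tracked` (for example when loading weights saved by a newer PyTorch into an older one whose BatchNorm lacks this buffer), one possible workaround is to filter them out before calling `load_state_dict`. A minimal sketch; `drop_num_batches_tracked` is a hypothetical helper, not a PyTorch API, and it does nothing about the `conv.` / `features.` naming mismatch.

```python
import torch.nn as nn


def drop_num_batches_tracked(state_dict):
    """Return a copy of state_dict without BatchNorm's num_batches_tracked buffers."""
    return {k: v for k, v in state_dict.items() if not k.endswith("num_batches_tracked")}


bn_state = nn.BatchNorm2d(8).state_dict()
print(sorted(bn_state))
# ['bias', 'num_batches_tracked', 'running_mean', 'running_var', 'weight']
print(sorted(drop_num_batches_tracked(bn_state)))
# ['bias', 'running_mean', 'running_var', 'weight']
```

Usage would then be `model.load_state_dict(drop_num_batches_tracked(torch.load(weight_name)))`, assuming every other key already matches.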

The takeaway:


# Define your model (or the parts you want to reuse) exactly as in the original code.
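When the original code is unavailable, the unexpected key names themselves hint at its structure. A minimal sketch reverse-engineered from keys in the question such as `conv.0.conv1.0.weight` and `conv.0.conv1.2.running_mean`. Everything here is an assumption: `DoubleConv` and `OriginalLike` are hypothetical names; index 1 is assumed to be a parameter-free ReLU (it has no keys, so it could equally be Dropout or similar); and the channel counts and kernel sizes are guesses. Wrong shapes would still fail with a "size mismatch" error even if all names match.

```python
import torch.nn as nn


class DoubleConv(nn.Module):
    """Hypothetical block: conv1.0 has weight+bias (Conv2d with bias),
    conv1.1 has no keys (assumed ReLU), conv1.2 has running stats (BatchNorm2d)."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_ch),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_ch),
        )


class OriginalLike(nn.Module):
    def __init__(self):
        super().__init__()
        # Top-level attribute named "conv" to match the file's "conv." prefix.
        self.conv = nn.Sequential(DoubleConv(3, 16), DoubleConv(16, 32))


keys = list(OriginalLike().state_dict().keys())
print(keys[:7])
# ['conv.0.conv1.0.weight', 'conv.0.conv1.0.bias', 'conv.0.conv1.2.weight',
#  'conv.0.conv1.2.bias', 'conv.0.conv1.2.running_mean', 'conv.0.conv1.2.running_var',
#  'conv.0.conv1.2.num_batches_tracked']
```

These first seven keys reproduce the order of the "Unexpected key(s)" list in the question, which is a quick way to check a reconstruction before trying to load real weights.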

Hope this helps :)