在state_dict中加载Pytourch 3.0模型意外键“ module.features.0.weight”时出现问题

时间:2018-10-23 08:30:31

标签: python-3.x pytorch onnx

我正在尝试加载使用Pytorch训练过的模型, 但我不断收到以下错误:

  

文件“ convert.py”,第12行,在       model.load_state_dict(torch.load('model / model_vgg2d_2.pth'))文件   “ /usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py”,   在load_state_dict中的第490行       .format(name))KeyError:“ state_dict中出现意外的键“ module.features.0.weight”

下面是我的代码:

import torch.onnx
import torch.nn as nn

class TempModel(nn.Module):
    def __init__(self):
        super(TempModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 5, (3, 3))
    def forward(self, inp):
        return self.conv1(inp)

model = nn.DataParallel(TempModel())
model.load_state_dict(torch.load('model/model_vgg2d_2.pth'))
dummy_input = Variable(torch.randn(1, 3, 224, 224))
torch.onnx.export(model, dummy_input, "model_onnx/model_vgg2d_0.onnx")

我正在与用于训练模型(具有多个GPU)的同一台机器上工作。 有什么想法我做错了吗?

1 个答案:

答案 0 :(得分:-1)

在加载state_dict时,您需要使其成为相同模型的state_dict:您无法将VGG模型的state_dict加载到完全不同的模型中BasicModel


旧答案
您保存的模型没有应用nn.DataParallel到模型,现在您要在添加模型之后尝试加载。试试

model = TempModel()
model.load_state_dict(torch.load('model/model_vgg2d_2.pth'))
model = nn.DataParallel(model)  # parallel AFTER load