我正在尝试加载使用Pytorch训练过的模型, 但我不断收到以下错误:
文件“ convert.py”,第12行,在 model.load_state_dict(torch.load('model / model_vgg2d_2.pth'))文件 “ /usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py”, 在load_state_dict中的第490行 .format(name))KeyError:“ state_dict中出现意外的键“ module.features.0.weight”
下面是我的代码:
import torch.onnx
import torch.nn as nn
class TempModel(nn.Module):
def __init__(self):
super(TempModel, self).__init__()
self.conv1 = nn.Conv2d(3, 5, (3, 3))
def forward(self, inp):
return self.conv1(inp)
model = nn.DataParallel(TempModel())
model.load_state_dict(torch.load('model/model_vgg2d_2.pth'))
dummy_input = Variable(torch.randn(1, 3, 224, 224))
torch.onnx.export(model, dummy_input, "model_onnx/model_vgg2d_0.onnx")
我正在与用于训练模型(具有多个GPU)的同一台机器上工作。 有什么想法我做错了吗?
答案 0 :(得分:-1)
在加载state_dict
时,您需要使其成为相同模型的state_dict
:您无法将VGG模型的state_dict
加载到完全不同的模型中BasicModel
。
旧答案
您保存的模型没有应用nn.DataParallel
到模型,现在您要在添加模型之后尝试加载。试试
model = TempModel()
model.load_state_dict(torch.load('model/model_vgg2d_2.pth'))
model = nn.DataParallel(model) # parallel AFTER load