I have a pretrained model called model.pth. I'm loading it and converting it to ONNX:
import torch
import torch.nn as nn

class TempModel(nn.Module):
    # No layers and no forward() are defined in this class.
    def dummyFunc(self):
        print("dummy")

model = TempModel()
state_dict = torch.load("/pathToModel.model.pth")
model.load_state_dict(state_dict, strict=False)
dummy_input = torch.randn(10, 3, 256, 256)  # batch of 10 RGB 256x256 images
torch.onnx.export(model, dummy_input, "myModel.onnx")
When I run it, I get this error:
Traceback (most recent call last):
File "onnxconvert.py", line 48, in <module>
torch.onnx.export(model, dummy_input, "myModel.onnx")
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/onnx/__init__.py", line 27, in export
return utils.export(*args, **kwargs)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/onnx/utils.py", line 104, in export
operator_export_type=operator_export_type)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/onnx/utils.py", line 281, in _export
example_outputs, propagate)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/onnx/utils.py", line 224, in _model_to_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/onnx/utils.py", line 192, in _trace_and_get_graph_from_model
trace, torch_out = torch.jit.get_trace_graph(model, args, _force_outplace=True)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/jit/__init__.py", line 197, in get_trace_graph
return LegacyTracedModule(f, _force_outplace)(*args, **kwargs)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
result = self.forward(*input, **kwargs)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/jit/__init__.py", line 252, in forward
out = self.inner(*trace_inputs)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 487, in __call__
result = self._slow_forward(*input, **kwargs)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 477, in _slow_forward
result = self.forward(*input, **kwargs)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 85, in forward
raise NotImplementedError
NotImplementedError
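The last frames of this traceback show the mechanism: torch.onnx.export traces the model by calling it, nn.Module.__call__ dispatches to self.forward, and TempModel only inherits the base-class forward, which raises NotImplementedError. The sketch below is a simplified, torch-free analogue of that call path (the Module base class here is a stand-in, not PyTorch's real implementation, which also runs hooks); FixedModel and try_call are hypothetical names used only for illustration.

```python
class Module(object):
    """Stand-in for nn.Module: calling an instance runs self.forward."""

    def forward(self, *inputs):
        # The base class has no computation of its own.
        raise NotImplementedError

    def __call__(self, *inputs):
        # The real nn.Module.__call__ also runs hooks; omitted here.
        return self.forward(*inputs)


class TempModel(Module):
    # Mirrors the question's class: a helper method, but no forward().
    def dummyFunc(self):
        print("dummy")


class FixedModel(Module):
    # Hypothetical subclass that does override forward().
    def forward(self, x):
        return x * 2


def try_call(model, x):
    """Call the model, reporting NotImplementedError instead of raising."""
    try:
        return model(x)
    except NotImplementedError:
        return "NotImplementedError"


print(try_call(TempModel(), 3))
print(try_call(FixedModel(), 3))
```

Only the subclass that overrides forward produces an output; tracing for ONNX export hits the same wall as the first call.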
If I change model.load_state_dict(state_dict, strict=False) to model.load_state_dict(state_dict), I get the following error:
Traceback (most recent call last):
File "onnxconvert.py", line 45, in <module>
model.load_state_dict(state_dict)
File "/Users/myUserName/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 769, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for TempModel:
Unexpected key(s) in state_dict: "fc.weight", "fc.bias", "head_0.conv_0.bias", "head_0.conv_0.weight_orig", "head_0.conv_0.weight_u", "head_0.conv_0.weight_v", "head_0.conv_1.bias", "head_0.conv_1.weight_orig", "head_0.conv_1.weight_u", "head_0.conv_1.weight_v"...
There are over 100 unexpected keys; I've cut the list down to just a few of them.
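The strict check amounts to comparing two sets of key names. Because TempModel defines no layers, its own state_dict() is empty, so every checkpoint key is "unexpected". With strict=False, load_state_dict skips those mismatches instead of raising, which means the call succeeds but no weights are actually loaded. The sketch below reproduces only that comparison in plain Python; compare_keys is a hypothetical helper, not a PyTorch function (in recent PyTorch versions, load_state_dict itself returns the missing and unexpected key lists).

```python
def compare_keys(expected_keys, checkpoint_keys):
    """Mimic the key check behind load_state_dict(strict=True).

    Returns (missing, unexpected): keys the model expects but the
    checkpoint lacks, and keys the checkpoint has but the model lacks.
    """
    expected = set(expected_keys)
    loaded = set(checkpoint_keys)
    missing = sorted(expected - loaded)
    unexpected = sorted(loaded - expected)
    return missing, unexpected


# TempModel defines no layers, so its state_dict() has no keys at all.
temp_model_keys = []
# A few of the checkpoint keys quoted in the error message above.
checkpoint_keys = ["fc.weight", "fc.bias",
                   "head_0.conv_0.bias", "head_0.conv_0.weight_orig"]

missing, unexpected = compare_keys(temp_model_keys, checkpoint_keys)
print("missing: {}".format(missing))
print("unexpected: {}".format(unexpected))
```

An empty "missing" list combined with a full "unexpected" list is the signature of a model class that does not match the checkpoint at all.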
It seems I should implement the forward method in TempModel, but the model has over 100 parameters and I didn't create it, so I'm not sure exactly how to do that.
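Without the original model code, the key names are the main clue to the architecture. A minimal sketch, assuming the checkpoint is a plain state_dict: group the keys by submodule path so each layer's tensors appear together. The weight_orig / weight_u / weight_v triple is what torch.nn.utils.spectral_norm registers for a parameter named "weight", so those layers were probably wrapped with spectral normalization; fc.weight and fc.bias suggest a layer named fc (likely nn.Linear). The helper names below are hypothetical, and real use would pass state_dict.keys() instead of the sample list.

```python
from collections import OrderedDict

# Tensor names torch.nn.utils.spectral_norm adds for a parameter "weight".
SPECTRAL_NORM_SUFFIXES = {"weight_orig", "weight_u", "weight_v"}


def group_keys_by_module(keys):
    """Map each submodule path (e.g. "head_0.conv_0") to its tensor names."""
    groups = OrderedDict()
    for key in keys:
        # "head_0.conv_0.bias" -> ("head_0.conv_0", ".", "bias")
        module_path, _, tensor_name = key.rpartition(".")
        groups.setdefault(module_path, []).append(tensor_name)
    return groups


def describe(keys):
    """One summary line per submodule, flagging likely spectral-norm layers."""
    lines = []
    for module_path, names in group_keys_by_module(keys).items():
        tag = " (spectral norm?)" if SPECTRAL_NORM_SUFFIXES <= set(names) else ""
        lines.append("{}: {}{}".format(module_path, ", ".join(names), tag))
    return lines


# Sample: the keys quoted in the error message.
keys = ["fc.weight", "fc.bias",
        "head_0.conv_0.bias", "head_0.conv_0.weight_orig",
        "head_0.conv_0.weight_u", "head_0.conv_0.weight_v",
        "head_0.conv_1.bias", "head_0.conv_1.weight_orig",
        "head_0.conv_1.weight_u", "head_0.conv_1.weight_v"]

for line in describe(keys):
    print(line)
```

Printing the tensor shapes alongside the names (e.g. state_dict[key].shape) would additionally reveal channel counts and kernel sizes, which are needed to rebuild each layer; the order in which forward should apply them still has to come from the original author's code.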
What can I do to successfully load and export the model? Please help!