NotImplementedError when converting a PyTorch model to ONNX

Time: 2019-04-28 08:06:55

Tags: machine-learning pytorch onnx

I have a pretrained model called model.pth. I am loading it and converting it to ONNX:

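import torch
import torch.nn as nn

class TempModel(nn.Module):
    # Placeholder class: it declares no layers and no forward() method.
    def dummyFunc(self):
        print("dummy")

model = TempModel()
state_dict = torch.load("/pathToModel.model.pth")
# strict=False skips state_dict keys that the model does not declare.
model.load_state_dict(state_dict, strict=False)

dummy_input = torch.randn(10, 3, 256, 256)
torch.onnx.export(model, dummy_input, "myModel.onnx")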

When I run it, I get this error:

Traceback (most recent call last):
File "onnxconvert.py", line 48, in <module>
  torch.onnx.export(model, dummy_input, "myModel.onnx")
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/onnx/__init__.py", line 27, in export
  return utils.export(*args, **kwargs)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/onnx/utils.py", line 104, in export
  operator_export_type=operator_export_type)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/onnx/utils.py", line 281, in _export
  example_outputs, propagate)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/onnx/utils.py", line 224, in _model_to_graph
  graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/onnx/utils.py", line 192, in _trace_and_get_graph_from_model
  trace, torch_out = torch.jit.get_trace_graph(model, args, _force_outplace=True)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/jit/__init__.py", line 197, in get_trace_graph
  return LegacyTracedModule(f, _force_outplace)(*args, **kwargs)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 489, in __call__
  result = self.forward(*input, **kwargs)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/jit/__init__.py", line 252, in forward
  out = self.inner(*trace_inputs)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 487, in __call__
  result = self._slow_forward(*input, **kwargs)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 477, in _slow_forward
  result = self.forward(*input, **kwargs)
File "/Users/sidyakinian/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 85, in forward
  raise NotImplementedError
NotImplementedError
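
The last frame of the traceback is the default `forward` of `nn.Module`, which raises `NotImplementedError` whenever a subclass does not override it. A minimal sketch of that behaviour follows; `NoForward` and `WithForward` are hypothetical classes invented for illustration, not part of the model in question.

```python
import torch
import torch.nn as nn


class NoForward(nn.Module):
    """Hypothetical module that, like TempModel, never defines forward()."""


class WithForward(nn.Module):
    """Hypothetical module with one layer and a forward() that uses it."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


dummy_input = torch.randn(2, 3, 16, 16)

# Calling a module dispatches to forward(); the nn.Module default raises.
try:
    NoForward()(dummy_input)
    raised = False
except NotImplementedError:
    raised = True

# The same call succeeds once forward() is defined.
output_shape = tuple(WithForward()(dummy_input).shape)
print(raised, output_shape)
```

Tracing for ONNX export calls the model on the dummy input in the same way, so the export fails at the same point.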

If I change model.load_state_dict(state_dict, strict=False) to model.load_state_dict(state_dict), I get the following error instead:

Traceback (most recent call last):
File "onnxconvert.py", line 45, in <module>
  model.load_state_dict(state_dict)
File "/Users/myUserName/anaconda2/lib/python2.7/site-packages/torch/nn/modules/module.py", line 769, in load_state_dict
  self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for TempModel:
Unexpected key(s) in state_dict: "fc.weight", "fc.bias", "head_0.conv_0.bias", "head_0.conv_0.weight_orig", "head_0.conv_0.weight_u", "head_0.conv_0.weight_v", "head_0.conv_1.bias", "head_0.conv_1.weight_orig", "head_0.conv_1.weight_u", "head_0.conv_1.weight_v"...

There are over 100 unexpected keys; I have cut the list short here.
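
The unexpected keys suggest that the checkpoint holds weights for layers the class never declares, so `strict=False` only hides the mismatch. The sketch below assumes a recent PyTorch release, where `load_state_dict` returns an object with `missing_keys` and `unexpected_keys` (older releases may return nothing). `EmptyModel` and the two-entry `checkpoint` are hypothetical stand-ins, not the real file.

```python
import torch
import torch.nn as nn


class EmptyModel(nn.Module):
    """Hypothetical stand-in for TempModel: it registers no layers."""


# Hypothetical checkpoint; the real one has more than 100 entries.
checkpoint = {
    "fc.weight": torch.zeros(4, 8),
    "fc.bias": torch.zeros(4),
}

model = EmptyModel()
result = model.load_state_dict(checkpoint, strict=False)

# With strict=False nothing is raised, but nothing is loaded either.
unexpected = sorted(result.unexpected_keys)
num_params = sum(p.numel() for p in model.parameters())

# Printing each name and shape hints at which layers must be declared.
for name, tensor in checkpoint.items():
    print(name, tuple(tensor.shape))
```

Key names can also carry hints about the architecture. Suffixes such as `weight_orig` and `weight_u`, for example, are the names that `torch.nn.utils.spectral_norm` registers on a wrapped layer, although the key list alone cannot confirm how the original model was built.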

It seems I should implement the forward method in TempModel, but the model has over 100 parameters and I did not create it, so I am not sure exactly how to do that.
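
For illustration, here is a rough sketch of what a loadable and exportable class could look like. `TinyModel`, its two layers, and the `export_to_onnx` helper are all hypothetical; the real layer definitions and `forward` would have to come from the code that produced model.pth, since a state_dict stores tensors but not the computation.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical architecture, NOT the one stored in model.pth."""

    def __init__(self):
        super().__init__()
        # Attribute names must match the checkpoint key prefixes ("conv.", "fc.").
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        x = x.mean(dim=(2, 3))  # global average pool -> (batch, 4)
        return self.fc(x)


def export_to_onnx(model, dummy_input, path):
    """Hypothetical helper: trace the model with dummy_input and write an ONNX file."""
    model.eval()
    torch.onnx.export(model, dummy_input, path)


# Stand-in for torch.load(...): a state_dict produced by the same class.
state_dict = TinyModel().state_dict()

model = TinyModel()
model.load_state_dict(state_dict)  # strict=True by default; keys match, so this succeeds
output_shape = tuple(model(torch.randn(1, 3, 32, 32)).shape)
```

Once the strict load succeeds and a forward pass works, a call such as `export_to_onnx(model, torch.randn(1, 3, 32, 32), "myModel.onnx")` should be able to trace the model.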

What should I do to load and export the model successfully? Please help!

0 answers:

No answers yet