将Pytorch模型.pth转换为onnx模型

时间:2018-04-24 16:51:30

标签: tensorflow deep-learning pytorch

我有一个预先训练过的模型,扩展为.pth格式。我想将其转换为Tensorflow protobuf。但我找不到任何办法。我已经看到onnx可以将模型从pytorch转换为onnx,然后从onnx转换为Tensorflow。但是通过这种方法,我在转换的第一阶段遇到了错误。

from torch.autograd import Variable
import torch.onnx
import torchvision
import torch 

dummy_input = Variable(torch.randn(1, 3, 256, 256))
model = torch.load('./my_model.pth')
torch.onnx.export(model, dummy_input, "moment-in-time.onnx")`

它给出了这样的错误。

File "t.py", line 9, in <module>
    torch.onnx.export(model, dummy_input, "moment-in-time.onnx")
  File "/usr/local/lib/python3.5/dist-packages/torch/onnx/__init__.py", line 75, in export
    _export(model, args, f, export_params, verbose, training)
  File "/usr/local/lib/python3.5/dist-packages/torch/onnx/__init__.py", line 108, in _export
    orig_state_dict_keys = model.state_dict().keys()
AttributeError: 'dict' object has no attribute 'state_dict'

什么是可能的解决方案?

2 个答案:

答案 0 :(得分:1)

这意味着您的模型不是torch.nn.Modules类的子类。如果将其作为子类,则应该可以使用。

答案 1 :(得分:1)

尝试将您的代码更改为此

from torch.autograd import Variable

import torch.onnx
import torchvision
import torch

dummy_input = Variable(torch.randn(1, 3, 256, 256))
state_dict = torch.load('./my_model.pth')
model.load_state_dict(state_dict)
torch.onnx.export(model, dummy_input, "moment-in-time.onnx")