Problem converting an LSTM PyTorch model to ONNX

Time: 2019-07-31 22:38:24

Tags: python deep-learning pytorch lstm onnx

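I am trying to export my LSTM anomaly-detection PyTorch model to ONNX, but I am running into an error. Please see my code below.

Note: my data has shape [2685, 5, 6]. Here is where I define the model: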

import torch
from torch import nn

class Model(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim):
        super(Model, self).__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        # batch_first=True: inputs are (batch, timesteps, features)
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        # Zero initial hidden and cell states, one per layer and batch element
        h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim)
        c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim)
        out, (hn, cn) = self.lstm(x, (h0, c0))
        out = self.fc1(out)
        out = self.fc2(out)
        return out

input_dim = 6
hidden_dim = 3
layer_dim = 2
model = Model(input_dim, hidden_dim, layer_dim)

I can train it and test it without problems. The problem comes when I export it:

model.eval()
import torch.onnx
torch_out = torch.onnx.export(model, 
                         torch.randn(2685, 5, 6), 
                         "onnx_model.onnx", 
                         export_params = True
                        )

But I get the following error:

LSTM(6, 3, num_layers=2, batch_first=True)
Linear(in_features=3, out_features=3, bias=True)
Linear(in_features=3, out_features=6, bias=True)
['input_1', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear']

/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/symbolic.py:173: UserWarning: ONNX export failed on RNN/GRU/LSTM because batch_first not supported
  warnings.warn("ONNX export failed on " + op + " because " + msg + " not supported")

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-264-28c6c55537ab> in <module>()
     10                          torch.randn(2685, 5, 6),
     11                          "onnx_model.onnx",
---> 12                          export_params = True
     13                         )

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/__init__.py in export(*args, **kwargs)
     23 def export(*args, **kwargs):
     24     from torch.onnx import utils
---> 25     return utils.export(*args, **kwargs)
     26 
     27 

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, strip_doc_string)
    129             operator_export_type=operator_export_type, opset_version=opset_version,
    130             _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,
--> 131             strip_doc_string=strip_doc_string)
    132 
    133 

~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, propagate, opset_version, _retain_param_name, do_constant_folding, strip_doc_string)
    367         if export_params:
    368             proto, export_map = graph._export_onnx(params_dict, opset_version, defer_weight_export, operator_export_type,
--> 369                                                    strip_doc_string)
    370         else:
    371             proto, export_map = graph._export_onnx({}, opset_version, False, operator_export_type, strip_doc_string)

RuntimeError: ONNX export failed: Couldn't export operator aten::lstm

Defined at:
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/rnn.py(522): forward_impl
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/rnn.py(539): forward_tensor
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/rnn.py(559): forward
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(481): _slow_forward
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(491): __call__
<ipython-input-255-468cef410a2c>(14): forward
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(481): _slow_forward
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(491): __call__
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/jit/__init__.py(294): forward
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(493): __call__
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/jit/__init__.py(231): get_trace_graph
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/utils.py(225): _trace_and_get_graph_from_model
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/utils.py(266): _model_to_graph
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/utils.py(363): _export
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/utils.py(131): export
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/__init__.py(25): export
<ipython-input-264-28c6c55537ab>(12): <module>
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/IPython/core/interactiveshell.py(2963): run_code
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/IPython/core/interactiveshell.py(2903): run_ast_nodes
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/IPython/core/interactiveshell.py(2785): _run_cell
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/IPython/core/interactiveshell.py(2662): run_cell
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/zmqshell.py(537): run_cell
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/ipkernel.py(208): do_execute
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/kernelbase.py(399): execute_request
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/kernelbase.py(233): dispatch_shell
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/kernelbase.py(283): dispatcher
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/tornado/stack_context.py(276): null_wrapper
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py(432): _run_callback
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py(480): _handle_recv
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py(450): _handle_events
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/tornado/stack_context.py(276): null_wrapper
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/tornado/platform/asyncio.py(117): _handle_events
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/asyncio/events.py(145): _run
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/asyncio/base_events.py(1432): _run_once
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/asyncio/base_events.py(422): run_forever
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/tornado/platform/asyncio.py(127): start
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/kernelapp.py(486): start
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/traitlets/config/application.py(658): launch_instance
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/__main__.py(3): <module>
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/runpy.py(85): _run_code
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/runpy.py(193): _run_module_as_main


Graph we tried to export:
graph(%input.1 : Float(2685, 5, 6),
      %lstm.weight_ih_l0 : Float(12, 6),
      %lstm.weight_hh_l0 : Float(12, 3),
      %lstm.bias_ih_l0 : Float(12),
      %lstm.bias_hh_l0 : Float(12),
      %lstm.weight_ih_l1 : Float(12, 3),
      %lstm.weight_hh_l1 : Float(12, 3),
      %lstm.bias_ih_l1 : Float(12),
      %lstm.bias_hh_l1 : Float(12),
      %fc1.weight : Float(3, 3),
      %fc1.bias : Float(3),
      %fc2.weight : Float(6, 3),
      %fc2.bias : Float(6)):
  %13 : Long() = onnx::Constant[value={0}](), scope: Model
  %14 : Tensor = onnx::Shape(%input.1), scope: Model
  %15 : Long() = onnx::Gather[axis=0](%14, %13), scope: Model
  %16 : Long() = onnx::Constant[value={2}](), scope: Model
  %17 : Long() = onnx::Constant[value={3}](), scope: Model
  %18 : Tensor = onnx::Unsqueeze[axes=[0]](%16)
  %19 : Tensor = onnx::Unsqueeze[axes=[0]](%15)
  %20 : Tensor = onnx::Unsqueeze[axes=[0]](%17)
  %21 : Tensor = onnx::Concat[axis=0](%18, %19, %20)
  %22 : Float(2, 2685, 3) = onnx::ConstantOfShape[value={0}](%21), scope: Model
  %23 : Long() = onnx::Constant[value={0}](), scope: Model
  %24 : Tensor = onnx::Shape(%input.1), scope: Model
  %25 : Long() = onnx::Gather[axis=0](%24, %23), scope: Model
  %26 : Long() = onnx::Constant[value={2}](), scope: Model
  %27 : Long() = onnx::Constant[value={3}](), scope: Model
  %28 : Tensor = onnx::Unsqueeze[axes=[0]](%26)
  %29 : Tensor = onnx::Unsqueeze[axes=[0]](%25)
  %30 : Tensor = onnx::Unsqueeze[axes=[0]](%27)
  %31 : Tensor = onnx::Concat[axis=0](%28, %29, %30)
  %32 : Float(2, 2685, 3) = onnx::ConstantOfShape[value={0}](%31), scope: Model
  %33 : Long() = onnx::Constant[value={1}](), scope: Model/LSTM[lstm]
  %34 : Long() = onnx::Constant[value={2}](), scope: Model/LSTM[lstm]
  %35 : Double() = onnx::Constant[value={0}](), scope: Model/LSTM[lstm]
  %36 : Long() = onnx::Constant[value={0}](), scope: Model/LSTM[lstm]
  %37 : Long() = onnx::Constant[value={0}](), scope: Model/LSTM[lstm]
  %38 : Long() = onnx::Constant[value={1}](), scope: Model/LSTM[lstm]
  %input.2 : Float(2685!, 5!, 3), %40 : Float(2, 2685, 3), %41 : Float(2, 2685, 3) = aten::lstm(%input.1, %22, %32, %lstm.weight_ih_l0, %lstm.weight_hh_l0, %lstm.bias_ih_l0, %lstm.bias_hh_l0, %lstm.weight_ih_l1, %lstm.weight_hh_l1, %lstm.bias_ih_l1, %lstm.bias_hh_l1, %33, %34, %35, %36, %37, %38), scope: Model/LSTM[lstm]
  %42 : Float(3!, 3!) = onnx::Transpose[perm=[1, 0]](%fc1.weight), scope: Model/Linear[fc1]
  %43 : Float(2685, 5, 3) = onnx::MatMul(%input.2, %42), scope: Model/Linear[fc1]
  %44 : Float(2685, 5, 3) = onnx::Add(%43, %fc1.bias), scope: Model/Linear[fc1]
  %45 : Float(3!, 6!) = onnx::Transpose[perm=[1, 0]](%fc2.weight), scope: Model/Linear[fc2]
  %46 : Float(2685, 5, 6) = onnx::MatMul(%44, %45), scope: Model/Linear[fc2]
  %47 : Float(2685, 5, 6) = onnx::Add(%46, %fc2.bias), scope: Model/Linear[fc2]
  return (%47)

What does this mean? What should I do to export the model correctly?

3 answers:

Answer 0 (score: 3)

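If you are coming here from Google, the previous answers are no longer up to date. ONNX now supports the LSTM operator. Be careful: unless you use the dynamic_axes parameter, exporting from PyTorch fixes the length of the input sequence by default.

Below is a minimal LSTM export example that I adapted from the torch.onnx FAQ:
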
import numpy as np
import onnx
import onnxruntime.backend as backend
import torch
from torch import nn

torch.manual_seed(0)

layer_count = 4

model = nn.LSTM(10, 20, num_layers=layer_count, bidirectional=True)
model.eval()

with torch.no_grad():
    # input is (sequence, batch, features); batch_first keeps its default (False)
    input = torch.randn(1, 3, 10)
    h0 = torch.randn(layer_count * 2, 3, 20)
    c0 = torch.randn(layer_count * 2, 3, 20)
    output, (hn, cn) = model(input, (h0, c0))

    # default export
    torch.onnx.export(model, (input, (h0, c0)), 'lstm.onnx')
    onnx_model = onnx.load('lstm.onnx')
    # input shape [1, 3, 10]
    print(onnx_model.graph.input[0])

    # export with `dynamic_axes`
    torch.onnx.export(model, (input, (h0, c0)), 'lstm.onnx',
                    input_names=['input', 'h0', 'c0'],
                    output_names=['output', 'hn', 'cn'],
                    dynamic_axes={'input': {0: 'sequence'}, 'output': {0: 'sequence'}})
    onnx_model = onnx.load('lstm.onnx')
    # input shape ['sequence', 3, 10]
    print(onnx_model.graph.input[0])

# Check export: run the ONNX model and compare against the PyTorch outputs
y, (hn, cn) = model(input, (h0, c0))
y_onnx, hn_onnx, cn_onnx = backend.run(
    onnx_model,
    [input.numpy(), h0.numpy(), c0.numpy()],
    device='CPU'
)

np.testing.assert_almost_equal(y_onnx, y.detach().numpy(), decimal=5)
np.testing.assert_almost_equal(hn_onnx, hn.detach().numpy(), decimal=5)
np.testing.assert_almost_equal(cn_onnx, cn.detach().numpy(), decimal=5)

I tested this example with: torch==1.4.0, onnx==1.7.0
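
The same dynamic_axes idea can be applied to the model from the question. The following is only a sketch, assuming a recent PyTorch release whose exporter handles nn.LSTM; it has not been checked against a specific version. The helper name `export_to_onnx`, the axis label "batch", and the small stand-in batch size are assumptions made for illustration. Here the batch dimension (axis 0, since the question uses batch_first=True) is marked dynamic so that the exported graph is not tied to 2685 samples.

```python
import torch
from torch import nn


class Model(nn.Module):
    """The question's architecture, relying on nn.LSTM's default zero initial states."""

    def __init__(self, input_dim, hidden_dim, layer_dim):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        # x: (batch, timesteps, features)
        out, _ = self.lstm(x)
        return self.fc2(self.fc1(out))


def export_to_onnx(model, example, target):
    """Hypothetical helper: export with a dynamic batch axis.

    `target` may be a file path or a file-like object such as io.BytesIO.
    """
    model.eval()
    torch.onnx.export(
        model,
        (example,),
        target,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
    )


model = Model(input_dim=6, hidden_dim=3, layer_dim=2)
model.eval()

# Small stand-in for the [2685, 5, 6] data from the question
example = torch.randn(4, 5, 6)
with torch.no_grad():
    out = model(example)  # (batch, timesteps, features)
```

Calling `export_to_onnx(model, example, "onnx_model.onnx")` would then write the file. Whether the export succeeds with batch_first=True depends on the installed PyTorch version; if it fails, a time-major layout is the fallback.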

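Answer 1 (score: 0)

You are not doing anything wrong.

RuntimeError: ONNX export failed: Couldn't export operator aten::lstm

LSTM is not in the list of supported operators given under the ONNX limitations.

Checking the GitHub issue queue for "RuntimeError on unsupported aten::" shows that there are more operator types that are not (yet) supported.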

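Answer 2 (score: 0)

Try using batch_first=False. The ONNX export does not support True. You will probably need to transpose your data, since the layout becomes (timesteps, batch, features) instead of (batch, timesteps, features).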
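
A minimal sketch of this suggestion, assuming the same layer sizes as in the question. The class name `TimeMajorModel` and the small batch size are illustrative, not from the original post. The model is built with the default batch_first=False, and the data is transposed to time-major before the forward pass and back afterwards.

```python
import torch
from torch import nn


class TimeMajorModel(nn.Module):
    """Variant of the question's model that expects (timesteps, batch, features)."""

    def __init__(self, input_dim, hidden_dim, layer_dim):
        super().__init__()
        # batch_first is left at its default, False
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim)
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        # x: (timesteps, batch, features); initial states default to zeros
        out, _ = self.lstm(x)
        return self.fc2(self.fc1(out))


model = TimeMajorModel(input_dim=6, hidden_dim=3, layer_dim=2)
model.eval()

# Small stand-in for the [2685, 5, 6] data: (batch, timesteps, features)
batch_major = torch.randn(8, 5, 6)
# Swap the first two axes to get (timesteps, batch, features)
time_major = batch_major.transpose(0, 1).contiguous()

with torch.no_grad():
    out = model(time_major)                # (timesteps, batch, features)
    out_batch_major = out.transpose(0, 1)  # back to (batch, timesteps, features)
```

The batch_first flag only changes how nn.LSTM interprets its input and output layout, not the parameter names or shapes, so weights trained with batch_first=True should load into this variant through `load_state_dict`.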