I am trying to export my LSTM anomaly-detection PyTorch model to ONNX, but I am running into an error. Please see my code below.
Note: my data has the shape [2685, 5, 6]. Here is where I define my model:
import torch
from torch import nn

class Model(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim):
        super(Model, self).__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
        self.fc1 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()
        c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()
        out, (hn, cn) = self.lstm(x, (h0.detach(), c0.detach()))
        out = self.fc1(out)
        out = self.fc2(out)
        return out

input_dim = 6
hidden_dim = 3
layer_dim = 2
model = Model(input_dim, hidden_dim, layer_dim)
I can train it and test it fine, but the problem comes when I export it:
model.eval()
import torch.onnx

torch_out = torch.onnx.export(model,
                              torch.randn(2685, 5, 6),
                              "onnx_model.onnx",
                              export_params=True)
But I get the following error:
LSTM(6, 3, num_layers=2, batch_first=True)
Linear(in_features=3, out_features=3, bias=True)
Linear(in_features=3, out_features=6, bias=True)
['input_1', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear', 'l_lstm_LSTM', 'l_fc1_Linear', 'l_fc2_Linear']
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/symbolic.py:173: UserWarning: ONNX export failed on RNN/GRU/LSTM because batch_first not supported
warnings.warn("ONNX export failed on " + op + " because " + msg + " not supported")
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-264-28c6c55537ab> in <module>()
10 torch.randn(2685, 5, 6),
11 "onnx_model.onnx",
---> 12 export_params = True
13 )
~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/__init__.py in export(*args, **kwargs)
23 def export(*args, **kwargs):
24 from torch.onnx import utils
---> 25 return utils.export(*args, **kwargs)
26
27
~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, aten, export_raw_ir, operator_export_type, opset_version, _retain_param_name, do_constant_folding, strip_doc_string)
129 operator_export_type=operator_export_type, opset_version=opset_version,
130 _retain_param_name=_retain_param_name, do_constant_folding=do_constant_folding,
--> 131 strip_doc_string=strip_doc_string)
132
133
~/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, propagate, opset_version, _retain_param_name, do_constant_folding, strip_doc_string)
367 if export_params:
368 proto, export_map = graph._export_onnx(params_dict, opset_version, defer_weight_export, operator_export_type,
--> 369 strip_doc_string)
370 else:
371 proto, export_map = graph._export_onnx({}, opset_version, False, operator_export_type, strip_doc_string)
RuntimeError: ONNX export failed: Couldn't export operator aten::lstm
Defined at:
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/rnn.py(522): forward_impl
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/rnn.py(539): forward_tensor
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/rnn.py(559): forward
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(481): _slow_forward
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(491): __call__
<ipython-input-255-468cef410a2c>(14): forward
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(481): _slow_forward
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(491): __call__
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/jit/__init__.py(294): forward
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/nn/modules/module.py(493): __call__
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/jit/__init__.py(231): get_trace_graph
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/utils.py(225): _trace_and_get_graph_from_model
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/utils.py(266): _model_to_graph
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/utils.py(363): _export
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/utils.py(131): export
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/torch/onnx/__init__.py(25): export
<ipython-input-264-28c6c55537ab>(12): <module>
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/IPython/core/interactiveshell.py(2963): run_code
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/IPython/core/interactiveshell.py(2903): run_ast_nodes
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/IPython/core/interactiveshell.py(2785): _run_cell
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/IPython/core/interactiveshell.py(2662): run_cell
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/zmqshell.py(537): run_cell
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/ipkernel.py(208): do_execute
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/kernelbase.py(399): execute_request
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/kernelbase.py(233): dispatch_shell
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/kernelbase.py(283): dispatcher
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/tornado/stack_context.py(276): null_wrapper
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py(432): _run_callback
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py(480): _handle_recv
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/zmq/eventloop/zmqstream.py(450): _handle_events
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/tornado/stack_context.py(276): null_wrapper
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/tornado/platform/asyncio.py(117): _handle_events
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/asyncio/events.py(145): _run
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/asyncio/base_events.py(1432): _run_once
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/asyncio/base_events.py(422): run_forever
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/tornado/platform/asyncio.py(127): start
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/kernelapp.py(486): start
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/traitlets/config/application.py(658): launch_instance
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/ipykernel/__main__.py(3): <module>
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/runpy.py(85): _run_code
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/runpy.py(193): _run_module_as_main
Graph we tried to export:
graph(%input.1 : Float(2685, 5, 6),
%lstm.weight_ih_l0 : Float(12, 6),
%lstm.weight_hh_l0 : Float(12, 3),
%lstm.bias_ih_l0 : Float(12),
%lstm.bias_hh_l0 : Float(12),
%lstm.weight_ih_l1 : Float(12, 3),
%lstm.weight_hh_l1 : Float(12, 3),
%lstm.bias_ih_l1 : Float(12),
%lstm.bias_hh_l1 : Float(12),
%fc1.weight : Float(3, 3),
%fc1.bias : Float(3),
%fc2.weight : Float(6, 3),
%fc2.bias : Float(6)):
%13 : Long() = onnx::Constant[value={0}](), scope: Model
%14 : Tensor = onnx::Shape(%input.1), scope: Model
%15 : Long() = onnx::Gather[axis=0](%14, %13), scope: Model
%16 : Long() = onnx::Constant[value={2}](), scope: Model
%17 : Long() = onnx::Constant[value={3}](), scope: Model
%18 : Tensor = onnx::Unsqueeze[axes=[0]](%16)
%19 : Tensor = onnx::Unsqueeze[axes=[0]](%15)
%20 : Tensor = onnx::Unsqueeze[axes=[0]](%17)
%21 : Tensor = onnx::Concat[axis=0](%18, %19, %20)
%22 : Float(2, 2685, 3) = onnx::ConstantOfShape[value={0}](%21), scope: Model
%23 : Long() = onnx::Constant[value={0}](), scope: Model
%24 : Tensor = onnx::Shape(%input.1), scope: Model
%25 : Long() = onnx::Gather[axis=0](%24, %23), scope: Model
%26 : Long() = onnx::Constant[value={2}](), scope: Model
%27 : Long() = onnx::Constant[value={3}](), scope: Model
%28 : Tensor = onnx::Unsqueeze[axes=[0]](%26)
%29 : Tensor = onnx::Unsqueeze[axes=[0]](%25)
%30 : Tensor = onnx::Unsqueeze[axes=[0]](%27)
%31 : Tensor = onnx::Concat[axis=0](%28, %29, %30)
%32 : Float(2, 2685, 3) = onnx::ConstantOfShape[value={0}](%31), scope: Model
%33 : Long() = onnx::Constant[value={1}](), scope: Model/LSTM[lstm]
%34 : Long() = onnx::Constant[value={2}](), scope: Model/LSTM[lstm]
%35 : Double() = onnx::Constant[value={0}](), scope: Model/LSTM[lstm]
%36 : Long() = onnx::Constant[value={0}](), scope: Model/LSTM[lstm]
%37 : Long() = onnx::Constant[value={0}](), scope: Model/LSTM[lstm]
%38 : Long() = onnx::Constant[value={1}](), scope: Model/LSTM[lstm]
%input.2 : Float(2685!, 5!, 3), %40 : Float(2, 2685, 3), %41 : Float(2, 2685, 3) = aten::lstm(%input.1, %22, %32, %lstm.weight_ih_l0, %lstm.weight_hh_l0, %lstm.bias_ih_l0, %lstm.bias_hh_l0, %lstm.weight_ih_l1, %lstm.weight_hh_l1, %lstm.bias_ih_l1, %lstm.bias_hh_l1, %33, %34, %35, %36, %37, %38), scope: Model/LSTM[lstm]
%42 : Float(3!, 3!) = onnx::Transpose[perm=[1, 0]](%fc1.weight), scope: Model/Linear[fc1]
%43 : Float(2685, 5, 3) = onnx::MatMul(%input.2, %42), scope: Model/Linear[fc1]
%44 : Float(2685, 5, 3) = onnx::Add(%43, %fc1.bias), scope: Model/Linear[fc1]
%45 : Float(3!, 6!) = onnx::Transpose[perm=[1, 0]](%fc2.weight), scope: Model/Linear[fc2]
%46 : Float(2685, 5, 6) = onnx::MatMul(%44, %45), scope: Model/Linear[fc2]
%47 : Float(2685, 5, 6) = onnx::Add(%46, %fc2.bias), scope: Model/Linear[fc2]
return (%47)
What does this mean, and what should I do to export the model correctly?
Answer 0 (score: 3)
If you are coming from Google, the previous answers are no longer up to date. ONNX now supports the LSTM operator. Be careful, though: unless you use the dynamic_axes argument, exporting from PyTorch will fix the input sequence length by default.
Below is a minimal LSTM export example I adapted from the torch.onnx FAQ:

import torch
import onnx
from torch import nn
import numpy as np
import onnxruntime.backend as backend
torch.manual_seed(0)

layer_count = 4

model = nn.LSTM(10, 20, num_layers=layer_count, bidirectional=True)
model.eval()

with torch.no_grad():
    input = torch.randn(5, 3, 10)
    h0 = torch.randn(layer_count * 2, 3, 20)
    c0 = torch.randn(layer_count * 2, 3, 20)
    output, (hn, cn) = model(input, (h0, c0))

    # default export
    torch.onnx.export(model, (input, (h0, c0)), 'lstm.onnx')
    onnx_model = onnx.load('lstm.onnx')
    # input shape [5, 3, 10]
    print(onnx_model.graph.input[0])

    # export with `dynamic_axes`
    torch.onnx.export(model, (input, (h0, c0)), 'lstm.onnx',
                      input_names=['input', 'h0', 'c0'],
                      output_names=['output', 'hn', 'cn'],
                      dynamic_axes={'input': {0: 'sequence'}, 'output': {0: 'sequence'}})
    onnx_model = onnx.load('lstm.onnx')
    # input shape ['sequence', 3, 10]

# Check export
y, (hn, cn) = model(input, (h0, c0))
y_onnx, hn_onnx, cn_onnx = backend.run(
    onnx_model,
    [input.numpy(), h0.numpy(), c0.numpy()],
    device='CPU'
)

np.testing.assert_almost_equal(y_onnx, y.detach(), decimal=5)
np.testing.assert_almost_equal(hn_onnx, hn.detach(), decimal=5)
np.testing.assert_almost_equal(cn_onnx, cn.detach(), decimal=5)
I have tested this example with torch==1.4.0 and onnx==1.7.0.
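As a quick check that the dynamic axis works, here is a small sketch of my own (not part of the original answer; it assumes onnxruntime is installed and reuses the input/output names from the dynamic_axes export above) that runs the exported model on a sequence length different from the one used at export time:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('lstm.onnx')
# sequence length 7 instead of the length 5 used during export
x = np.random.randn(7, 3, 10).astype(np.float32)
h0 = np.zeros((8, 3, 20), dtype=np.float32)  # layer_count * 2 directions = 8
c0 = np.zeros((8, 3, 20), dtype=np.float32)
y, hn, cn = sess.run(None, {'input': x, 'h0': h0, 'c0': c0})
print(y.shape)  # (7, 3, 40): hidden size 20 * 2 directions

Without dynamic_axes, onnxruntime should reject the length-7 input, because the first input dimension is baked into the graph as 5.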
Answer 1 (score: 0)
You are not doing anything wrong:
RuntimeError: ONNX export failed: Couldn't export operator aten::lstm
LSTM is not in the list of supported operators (see the onnx limitations page).
Looking through the GitHub issue queue for "RuntimeError on unsupported aten::" reports, there are more operator types that are not (yet) supported.
Answer 2 (score: 0)
Try using batch_first=False; ONNX does not support True. You may need to transpose your data, because the expected layout then becomes (timesteps, batch, features) instead of (batch, timesteps, features). A sketch of this change is below.
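For the model in the question, a rough sketch of what that change could look like (my own adaptation, untested; the permute calls are added so callers can keep passing batch-first data):

import torch
from torch import nn

class Model(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim):
        super(Model, self).__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        # batch_first=False is the default: the LSTM expects (timesteps, batch, features)
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim)
        self.fc1 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        # x arrives as (batch, timesteps, features); swap to (timesteps, batch, features)
        x = x.permute(1, 0, 2)
        h0 = torch.zeros(self.layer_dim, x.size(1), self.hidden_dim)
        c0 = torch.zeros(self.layer_dim, x.size(1), self.hidden_dim)
        out, (hn, cn) = self.lstm(x, (h0, c0))
        out = self.fc2(self.fc1(out))
        # swap back so the output is (batch, timesteps, features) again
        return out.permute(1, 0, 2)

model = Model(6, 3, 2)
model.eval()
torch.onnx.export(model, torch.randn(2685, 5, 6), "onnx_model.onnx", export_params=True)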