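I am trying to build an ONNX graph with the helper API. The simplest example I started with is the following: a MatMul operator that takes two inputs of shape [1] (X and W) and produces an output Y of shape [1].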
import numpy as np
import onnxruntime as rt
from onnx import *
from onnxmltools.utils import save_model
initializer = []
initializer.append(helper.make_tensor(name="W", data_type=TensorProto.FLOAT, dims=(1,), vals=np.ones(1).tolist()))
graph = helper.make_graph(
    [
        helper.make_node('MatMul', ["X", "W"], ["Y"]),
    ],
    "TEST",
    [
        helper.make_tensor_value_info('X', TensorProto.FLOAT, [1]),
        helper.make_tensor_value_info('W', TensorProto.FLOAT, [1]),
    ],
    [
        helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1]),
    ],
    initializer=initializer,
)
checker.check_graph(graph)
model = helper.make_model(graph, producer_name='TEST')
save_model(model, "model.onnx")
sess = rt.InferenceSession('model.onnx')
When I run it, it fails with this error:
Traceback (most recent call last):
  File "onnxruntime_test.py", line 35, in <module>
    sess = rt.InferenceSession('model.onnx')
  File "/usr/local/lib/python3.5/dist-packages/onnxruntime/capi/session.py", line 29, in __init__
    self._sess.load_model(path_or_bytes)
RuntimeError: [ONNXRuntimeError] : 1 : GENERAL ERROR : Node: Output:Y [ShapeInferenceError] Mismatch between number of source and target dimensions. Source=0 Target=1
I have been stuck on this for hours. Can anyone help?
Answer 0 (score: 0)
See https://github.com/microsoft/onnxruntime/issues/380
I changed a few things to make your code work. Here is the new version:
import numpy as np
import onnxruntime as rt
import onnx
import onnx.utils
from onnx import TensorProto, checker, helper
initializer = []
initializer.append(helper.make_tensor(name="W", data_type=TensorProto.FLOAT, dims=(1,), vals=np.ones(1).tolist()))
graph = helper.make_graph(
    [
        helper.make_node('MatMul', ["X", "W"], ["Y"]),
    ],
    "TEST",
    [
        helper.make_tensor_value_info('X', TensorProto.FLOAT, [1]),
        helper.make_tensor_value_info('W', TensorProto.FLOAT, [1]),
    ],
    [
        # MatMul of two 1-D tensors is a dot product, so Y is a scalar: shape []
        helper.make_tensor_value_info('Y', TensorProto.FLOAT, []),
    ],
    initializer=initializer,
)
checker.check_graph(graph)
model = helper.make_model(graph, producer_name='TEST')
# polish_model (check + shape inference + optimization) only exists in older
# onnx releases; it is not needed for the fix, so onnx.save(model, ...) also works
final_model = onnx.utils.polish_model(model)
onnx.save(final_model, 'model.onnx')
sess = rt.InferenceSession('model.onnx')
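To represent a scalar, use the shape [] instead of [1]. MatMul follows numpy.matmul semantics, so multiplying two 1-D tensors (a dot product) yields a rank-0 scalar; that is the Source=0 in the error, while the declared output [1] has rank 1 (Target=1).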
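The ONNX operator documentation describes MatMul as behaving like numpy.matmul, so NumPy alone can preview the output rank for a given pair of input shapes. A minimal sketch in plain NumPy; the shape pairs are illustrative and not taken from the question:

```python
import numpy as np

# Preview MatMul output shapes with numpy.matmul, whose rank rules ONNX MatMul follows.
vec = np.ones(1, dtype=np.float32)        # shape (1,), like X and W above
mat = np.ones((1, 1), dtype=np.float32)   # shape (1, 1)

shapes = {
    "[1] x [1]": np.shape(np.matmul(vec, vec)),        # dot product -> rank 0
    "[1] x [1, 1]": np.shape(np.matmul(vec, mat)),     # vector-matrix -> rank 1
    "[1, 1] x [1, 1]": np.shape(np.matmul(mat, mat)),  # matrix-matrix -> rank 2
}
for inputs, out_shape in shapes.items():
    print(f"{inputs:>16} -> {list(out_shape)}")
```

It prints [] for [1] x [1], [1] for [1] x [1, 1], and [1, 1] for [1, 1] x [1, 1]. So if the graph really needs a [1]-shaped Y, one option, assuming a vector-matrix product is what is intended, is to give W the shape [1, 1]; two [1] vectors always multiply to a scalar.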