使用python onnxruntime进行预测时出错

时间:2019-11-26 09:08:24

标签: python scikit-learn sklearn-pandas onnx onnxruntime

我使用sklearn库创建了一个非常基本的决策树。该树是根据以下4个特征进行训练的:

feat1 INT
feat2 INT
feat3 FLOAT
feat4 FLOAT

标签/目标特征是一个布尔值(0或1)。

我将树转换为ONNX格式,现在我想使用onnxruntime python库进行预测。我已经在互联网上找到示例代码来做到这一点。问题是我不完全了解这段代码,函数和参数的所有部分到底发生了什么。这导致我得到一个错误。我确实搜索了一些文档,但是找不到。

在下面的代码中,我将树模型转换为ONNX格式。这是成功的,但是我不理解的部分代码。在initial_type变量中,根据我之前提到的4个功能列和标签/目标功能,我必须在此处输入什么?现在我进入FloatTensorType([None, 4]是因为我有4个功能列,而None我不知道是什么。

##Convert to ONNX format

initial_type = [('float_input', FloatTensorType([None, 4]))]
onx = convert_sklearn(treeModel, initial_types=initial_type)
with open("path", "wb") as f:
    f.write(onx.SerializeToString())

在下面的代码中,我想使用onnxruntime库进行预测,但出现此错误:

RuntimeError: Either type_proto was null or it was not of sequence type

这是因为我不理解下面的代码的最后一行。我输入此{input_name: [4, 8, 77.8, 143.45]是因为这是功能列的四个值。我在这里做什么错了?

sess = rt.InferenceSession("pathToONNXModel")
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
pred_onx = sess.run([label_name], {input_name: [4, 8, 77.8, 143.45]})[0]

1 个答案:

答案 0 :(得分:1)

您尝试过{input_name: numpy.array([4, 8, 77.8, 143.45], dtype=numpy.float32)}吗? onnxruntime需要numpy数组作为输入。