我正在尝试对已打包为ONNX的sklearn管道中的文本进行预测。我可以写出并读取模型,但是当我做出预测时,我得到了错误: “由于以下原因导致方法运行失败:[ONNXRuntimeError]:2:INVALID_ARGUMENT:缺少所需的输入:float_input”。有谁知道如何通过sklearn管道对文本进行预测?
我已经按照本教程http://onnx.ai/sklearn-onnx/auto_examples/plot_tfidfvectorizer.html#sphx-glr-download-auto-examples-plot-tfidfvectorizer-py进行操作,但是无法做出预测。
'''
#convert pipeline into onnx
model_onnx = convert_sklearn(pipeline, "tfidf",
initial_types=[("str_input", StringTensorType([1, 2000]))])
with open("pipeline_emails.onnx", "wb") as f:
f.write(onx.SerializeToString())
#make predictions on test data
sess = rt.InferenceSession("pipeline_emails.onnx")
pred_onx = sess.run(None, {"str_input": test_df.as_matrix()})[0]
print("predict", pred_onx[0])
print("predict_proba", pred_onx[1])
'''
RuntimeError Traceback (most recent call last)
<ipython-input-118-5db056b989a8> in <module>()
2 sess = rt.InferenceSession("pipeline_emails.onnx")
3 inputs = {'str_input': test_df.as_matrix()}
----> 4 pred_onx = sess.run(None, {"str_input": test_df.as_matrix()})[0]
5 print("predict", pred_onx[0])
6 print("predict_proba", pred_onx[1])
~\AppData\Local\Continuum\anaconda3\lib\site-packages\onnxruntime\capi\session.py in run(self, output_names, input_feed, run_options)
70 if not output_names:
71 output_names = [output.name for output in self._outputs_meta]
---> 72 return self._sess.run(output_names, input_feed, run_options)
73
74 def end_profiling(self):
RuntimeError: Method run failed due to: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Missing required input: float_input