numpy和scikit-learn导致C ++应用程序崩溃

时间:2018-08-22 22:12:57

标签: python c++ numpy scikit-learn

我正在构建一个应用程序,需要在其中嵌入带有C ++代码的Python脚本。 C ++代码是调用者。 Python代码(pyFile.py)使用scikit-learn和numpy实现了机器学习模型。

在我的用例中,我需要在一次C ++程序运行期间多次调用python代码。 python脚本运行平稳,并在C ++首次调用它时吐出结果。但是,当第二次调用时,整个应用程序崩溃。

以下是一个MWE,可为您提供一个想法:

C ++(呼叫者代码):

int callPlugin() {
  auto modelPath = "mlTrainerAndPredictors/get_predictions.py";
  FILE* fp = fopen(modelPath, "r");
  if (fp == NULL) {
    perror("Couldn't open file");
    return -1;
  }
  PyRun_SimpleFile(fp, modelPath);
  fclose(fp);

  return 0;
}

int getPrediction() {
  Py_Initialize();
  CPyObject sysPath = PySys_GetObject((char*)"path");
  CPyObject curDir = PyBytes_FromString("mlTrainerAndPredictors");
  PyList_Append(sysPath, curDir);

  if (callPlugin() == -1)
    return -1;
  printResult();
  Py_Finalize();

  return 0;
}

int main() {
  for (int i = 0; i < 2; ++i)
    getPrediction() // this runs the first time, fails after that
  return 0;
}

Python(正在调用ML模型):

from sklearn.externals import joblib
import numpy as np
import json
import os

# load the classifier from file
clf = joblib.load('mlClassifiers/clf.pkl')

# load test data from file
test_data = []
with open('mlTrainerAndPredictors/test_data_file.json', 'r') as test_data_file:
    tmp = json.load(test_data_file)
    for data in tmp['test_data']:
        test_data.append(data)

# convert data to numpy array
numpy_array = np.asarray(test_data)

# predict using saved classifier
predictions = np.array(clf.predict(numpy_array), dtype=np.float64)
print "Predictions:"
print predictions

test_data_file.json

{"test_data": [
  [ 0.0, 0.0,  5.0, 13.0,  9.0,  1.0, 0.0, 0.0,
    0.0, 0.0, 13.0, 15.0, 10.0, 15.0, 5.0, 0.0,
    0.0, 3.0, 15.0,  2.0,  0.0, 11.0, 8.0, 0.0,
    0.0, 4.0, 12.0,  0.0,  0.0,  8.0, 8.0, 0.0,
    0.0, 5.0,  8.0,  0.0,  0.0,  9.0, 8.0, 0.0,
    0.0, 4.0, 11.0,  0.0,  1.0, 12.0, 7.0, 0.0,
    0.0, 2.0, 14.0,  5.0, 10.0, 12.0, 0.0, 0.0,
    0.0, 0.0,  6.0, 13.0, 10.0,  0.0, 0.0, 0.0 ],
  [ 0.0, 0.0,  0.0, 12.0, 13.0,  5.0, 0.0, 0.0,
    0.0, 0.0,  0.0, 11.0, 16.0,  9.0, 0.0, 0.0,
    0.0, 0.0,  3.0, 15.0, 16.0,  6.0, 0.0, 0.0,
    0.0, 7.0, 15.0, 16.0, 16.0,  2.0, 0.0, 0.0,
    0.0, 0.0,  1.0, 16.0, 16.0,  3.0, 0.0, 0.0,
    0.0, 0.0,  1.0, 16.0, 16.0,  6.0, 0.0, 0.0,
    0.0, 0.0,  1.0, 16.0, 16.0,  6.0, 0.0, 0.0,
    0.0, 0.0,  0.0, 11.0, 16.0, 10.0, 0.0, 0.0 ]
]
}

0 个答案:

没有答案