Python scikit-learn:导出训练有素的分类器

时间:2013-07-07 12:00:35

标签: python scikit-learn

我正在使用nolearn基于scikit-learn的DBN(深度信念网络)。

我已经构建了一个可以很好地对我的数据进行分类的网络,现在我有兴趣导出模型进行部署,但我不知道如何(我每次想要预测某些内容时都在训练DBN) )。在matlab中,我只需导出权重矩阵并将其导入另一台机器。

有人知道如何导出要导入的模型/权重矩阵而无需再次训练整个模型吗?

3 个答案:

答案 0 :(得分:61)

您可以使用:

>>> from sklearn.externals import joblib
>>> joblib.dump(clf, 'my_model.pkl', compress=9)

然后,在预测服务器上:

>>> from sklearn.externals import joblib
>>> model_clone = joblib.load('my_model.pkl')

这基本上是一个Python pickle,具有针对大型numpy数组的优化处理。它与常规泡菜w.r.t具有相同的局限性。代码更改:如果pickle对象的类结构发生更改,则可能无法再使用nolearn或scikit-learn的新版本对该对象进行unpickle。

如果你想要长期稳健的存储模型参数的方法,你可能需要编写自己的IO层(例如使用二进制格式的序列化工具,如协议缓冲区或avro或低效但可移植的文本/ json / xml表示,如作为PMML)。

答案 1 :(得分:9)

Pickling / unpickling的缺点是它只适用于匹配的python版本(主要版本,也可能是次要版本)和sklearn,joblib库版本。

机器学习模型还有其他描述性输出格式,例如Data Mining Group开发的格式,例如预测模型标记语言(PMML)和可移植分析格式(PFA)。在这两者中,PMML是much better supported

因此,您可以选择将模型从scikit-learn保存到PMML(例如使用sklearn2pmml),然后使用jpmml在(java,spark或hive)中部署和运行它。当然你有更多的选择)。

答案 2 :(得分:3)

scikit-learn文档中的3.4. Model persistence部分几乎涵盖了所有内容。

除了sklearn.externals.joblib ogrisel指出,它还显示了如何使用常规泡菜包:

>>> from sklearn import svm
>>> from sklearn import datasets
>>> clf = svm.SVC()
>>> iris = datasets.load_iris()
>>> X, y = iris.data, iris.target
>>> clf.fit(X, y)  
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0,
  kernel='rbf', max_iter=-1, probability=False, random_state=None,
  shrinking=True, tol=0.001, verbose=False)

>>> import pickle
>>> s = pickle.dumps(clf)
>>> clf2 = pickle.loads(s)
>>> clf2.predict(X[0])
array([0])
>>> y[0]
0

并提供一些警告,例如在一个版本的scikit-learn中保存的模型可能无法在另一个版本中加载。