将python随机林模型保存到文件

时间:2013-12-18 15:25:50

标签: python machine-learning scikit-learn random-forest

在R中,跑完"随机森林"我可以使用save.image("***.RData")来存储模型。之后,我可以加载模型直接进行预测。

你能在python中做类似的事情吗?我将模型和预测分成两个文件。在模型文件中:

rf= RandomForestRegressor(n_estimators=250, max_features=9,compute_importances=True)
fit= rf.fit(Predx, Predy)

我尝试返回rffit,但仍然无法在预测文件中加载模型。

您可以使用sklearn随机森林包分离模型和预测吗?

5 个答案:

答案 0 :(得分:23)

...
import cPickle

rf = RandomForestRegresor()
rf.fit(X, y)

with open('path/to/file', 'wb') as f:
    cPickle.dump(rf, f)


# in your prediction file                                                                                                                                                                                                           

with open('path/to/file', 'rb') as f:
    rf = cPickle.load(f)


preds = rf.predict(new_X)

答案 1 :(得分:2)

您可以使用joblib从scikit-learn(实际上是scikit-learn的任何模型)中保存和加载随机森林

示例:

import joblib
from sklearn.ensemble import RandomForestClassifier
# create RF
rf = RandomForestClassifier()
# fit on some data
rf.fit(X, y)

# save
joblib.dump(rf, "my_random_forest.joblib")

# load
loaded_rf = joblib.load("my_random_forest.joblib")

还有joblib.dump has compress参数,因此可以压缩模型。我在虹膜数据集上做了非常简单的test,并且compress=3缩小了文件大小约5.6倍。

答案 2 :(得分:1)

我使用dill,它存储所有数据,我认为可能是模块信息?也许不吧。我记得尝试使用pickle存储这些非常复杂的对象,它对我不起作用。 cPickle可能与dill完成相同的工作,但我从未尝试cpickle。它看起来像字面上完全相同的方式。我使用“obj”扩展名,但这绝不是传统的...因为我存储了一个对象,所以对我来说才有意义。

import dill
wd = "/whatever/you/want/your/working/directory/to/be/"
rf= RandomForestRegressor(n_estimators=250, max_features=9,compute_importances=True)
rf.fit(Predx, Predy)
dill.dump(rf, open(wd + "filename.obj","wb"))
顺便说一句,不确定你是否使用iPython,但有时候写文件并不是这样你必须这样做:

with open(wd + "filename.obj","wb") as f:
    dill.dump(rf,f)

再次调用对象:

model = dill.load(open(wd + "filename.obj","rb"))

答案 3 :(得分:0)

对于模型存储,您也可以使用.sav formate。它存储完整的模型和信息。

答案 4 :(得分:0)

我要重申joblib做得很好,it provides really good compression options(即lzma)。

with open("clf.pkl", "wb") as out: pickle.dump(clf, out)
with open("clf.dill", "wb") as out: dill.dump(clf, out)
joblib.dump(clf, "clf.jbl")
joblib.dump(clf, "clf.jbl.lzma")
joblib.dump(clf, "clf.jbl.gz")

!du clf.*
24576   clf.dill
24576   clf.jbl
5120    clf.jbl.gz
3072    clf.jbl.lzma
24576   clf.pkl