在cross_validate之后如何导出/保存拟合模型并稍后在熊猫上使用

时间:2019-07-05 13:46:36

标签: python pandas scikit-learn random-forest

我正在使用cross_validate sklearn函数来拟合RandomForest分类器。 我想知道是否有一种方法可以导出拟合的模型以保存它们,并可以导入以预测新数据。

我尝试使用return_estimator=True选项

  

[return_estimator:布尔值,默认为False是否返回   估计值适合每个分割。]

,然后joblib保存估算器。但是,当我加载保存的模型并尝试将其用于predict时,出现了错误,(请参见下文)。

rfc = RandomForestClassifier(n_estimators=100)
cv_results = cross_validate(rfc, X_train_std ,Y_train, scoring=scoring, cv=5, return_estimator=True)
rfc_fit = cv_results['estimator']

#save estimated model
savedir = ('C://Users//.......//src//US//') 

from sklearn.externals import joblib
filename = os.path.join(savedir, 'final_model.joblib')
joblib.dump(rfc_fit,filename)

rfc_model2 = joblib.load(filename)
bla = rfc_model2.predict(X_test_std)

AttributeError: 'tuple' object has no attribute 'predict'

我想我对return_estimator的真正回报感到困惑。 看起来他们不是合适的模特。那么,有没有办法在交叉验证期间提取适合的模型以重新使用它们?

谢谢

1 个答案:

答案 0 :(得分:0)

return_estimator返回所有拟合模型的“元组”。

要解决此问题,您需要选择所需的模型,将其保存,加载并进行预测。

示例:

from sklearn import datasets, linear_model
from sklearn.model_selection import cross_validate

diabetes = datasets.load_diabetes()
X = diabetes.data[:150]
y = diabetes.target[:150]
lasso = linear_model.Lasso()

cv_results = cross_validate(lasso, X, y, cv=3, return_estimator=True)
rfc_fit = cv_results['estimator']
print(rfc_fit)

上面打印了3个模型:

  

(套索(alpha = 1.0,copy_X = True,fit_intercept = True,max_iter = 1000,
  normalize = False,正= False,precompute = False,random_state =无,   selection ='循环',tol = 0.0001,warm_start = False),Lasso(alpha = 1.0,   copy_X = True,fit_intercept = True,max_iter = 1000,normalize = False,   正=假,预计算=假,random_state =无,
  selection ='循环',tol = 0.0001,warm_start = False),Lasso(alpha = 1.0,   copy_X = True,fit_intercept = True,max_iter = 1000,normalize = False,   正=假,预计算=假,random_state =无,
  selection ='cyclic',tol = 0.0001,warm_start = False))

要查看包含的模型数量,请执行以下操作:

print(len(rfc_fit))
# 3

假设您要选择第一个模型:

# select the first model
rfc_fit = rfc_fit[0]

# save it
from sklearn.externals import joblib
filename = os.path.join(savedir, 'final_model.joblib')
joblib.dump(rfc_fit,filename)

# load it
rfc_model2 = joblib.load(filename)

Predict现在可以正常工作:

predicted = rfc_model2.predict(X)