如何保存XGboost base_learners?

时间:2019-01-30 01:05:46

标签: python multithreading xgboost

我正在使用XGBoost算法学习集成模型

当我打印base_learners时,它似乎被存储为字典类型。 像这样:

{'dnn': <keras.engine.sequential.Sequential object at 0x000001BB50B97C88>, 'random forest': RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=4, max_features='sqrt', max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=2, min_samples_split=2,
            min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=-1,
            oob_score=False, random_state=42, verbose=0, warm_start=False), 'extra trees': ExtraTreesClassifier(bootstrap=False, class_weight=None, criterion='gini',
           max_depth=4, max_features='auto', max_leaf_nodes=None,
           min_impurity_decrease=0.0, min_impurity_split=None,
           min_samples_leaf=2, min_samples_split=2,
           min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=-1,
           oob_score=False, random_state=42, verbose=0, warm_start=False)}

要在另一个文件中使用“ base_learner”,如何保存? 我不能使用save_model()。因为那不是模型

而且,我也不能使用pickle模块。我不知道为什么。

但是我认为多线程错误问题。

当我使用pickle模块时,出现以下错误消息:

pickle.dump(base_learners, open('./models/base_learners.pkl', 'wb'))
TypeError: can't pickle _thread.RLock objects

如何解决此问题?

1 个答案:

答案 0 :(得分:0)

尝试使用sckit学习库。

from sklearn.externals import joblib  
joblib.dump(test_model, "trained-model.pkl")

从文件中加载经过训练的模型

test_model = joblib.load("trained-model.pkl")

检查是否有效。这里test_model的值为

LogisticRegressionCV(Cs=3, class_weight='balanced', cv=10, dual=False,
           fit_intercept=True, intercept_scaling=1.0, max_iter=100,
           multi_class='ovr', n_jobs=-1, penalty='l2', random_state=42,
           refit=False, scoring=None, solver='lbfgs', tol=0.0001,
           verbose=0)

现在,如果它对您不起作用,那么您需要取出关键值“ random forest”和“ extra trees”并将其另存为pickle文件。我只是尝试将此值与我的值进行比较,两者看起来都相似。一次,您可以重新加载这两个值,也可以重新创建原始的base_learners。