最近,我一直在为使用Tensorflow后端的Keras中的超参数调整应用网格搜索交叉验证(sklearn GridSearchCV)。我的模型一经调整 我试图保存GridSearchCV对象供以后使用而没有成功。
超参数调整如下:
x_train, x_val, y_train, y_val = train_test_split(NN_input, NN_target, train_size = 0.85, random_state = 4)
history = History()
kfold = 10
regressor = KerasRegressor(build_fn = create_keras_model, epochs = 100, batch_size=1000, verbose=1)
neurons = np.arange(10,101,10)
hidden_layers = [1,2]
optimizer = ['adam','sgd']
activation = ['relu']
dropout = [0.1]
parameters = dict(neurons = neurons,
hidden_layers = hidden_layers,
optimizer = optimizer,
activation = activation,
dropout = dropout)
gs = GridSearchCV(estimator = regressor,
param_grid = parameters,
scoring='mean_squared_error',
n_jobs = 1,
cv = kfold,
verbose = 3,
return_train_score=True))
grid_result = gs.fit(NN_input,
NN_target,
callbacks=[history],
verbose=1,
validation_data=(x_val, y_val))
备注:create_keras_model函数初始化并编译Keras顺序模型。
执行交叉验证后,我尝试使用以下代码保存网格搜索对象(gs):
from sklearn.externals import joblib
joblib.dump(gs, 'GS_obj.pkl')
我得到的错误如下:
TypeError: can't pickle _thread.RLock objects
能否让我知道此错误的原因是什么?
谢谢!
P.S .: joblib.dump方法对于保存使用的GridSearchCV对象效果很好 用于sklearn的MLPRegressors培训。
答案 0 :(得分:2)
使用
import joblib
直接
代替
from sklearn.externals import joblib
使用以下方法保存对象或结果:
joblib.dump(gs, 'model_file_name.pkl')
并使用以下方法加载结果:
joblib.load("model_file_name.pkl")
这是一个简单的工作示例:
import joblib
#save your model or results
joblib.dump(gs, 'model_file_name.pkl')
#load your model for further usage
joblib.load("model_file_name.pkl")
答案 1 :(得分:0)
尝试一下:
from sklearn.externals import joblib
joblib.dump(gs.best_estimator_, 'filename.pkl')
如果要将对象转储到一个文件中,请使用:
joblib.dump(gs.best_estimator_, 'filename.pkl', compress = 1)
简单示例:
from sklearn import svm, datasets
from sklearn.model_selection import GridSearchCV
from sklearn.externals import joblib
iris = datasets.load_iris()
parameters = {'kernel':('linear', 'rbf'), 'C':[1, 10]}
svc = svm.SVC()
gs = GridSearchCV(svc, parameters)
gs.fit(iris.data, iris.target)
joblib.dump(gs.best_estimator_, 'filename.pkl')
#['filename.pkl']
编辑1:
您还可以保存整个对象:
joblib.dump(gs, 'gs_object.pkl')
答案 2 :(得分:0)
为sklearn.model_selection._search.BaseSearchCV
类子类。覆盖fit(self, X, y=None, groups=None, **fit_params)
方法,并修改其内部evaluate_candidates(candidate_params)
函数。而不是立即从results
返回evaluate_candidates(candidate_params)
字典,而是在此处(或根据您的使用情况,在_run_search
方法中)执行序列化。经过一些额外的修改,此方法的另一个好处是允许您顺序执行网格搜索(请参见此处的源代码中的注释:_search.py)。请注意,results
返回的evaluate_candidates(candidate_params)
字典与cv_results
字典相同。这种方法对我有用,但是我也试图为中断的网格搜索执行添加保存和恢复功能。