Multiple estimators in a pipeline with GridSearchCV show the wrong parameters in cv_results_

Posted: 2019-07-14 13:01:26

Tags: python pipeline gridsearchcv

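I am using imblearn.pipeline together with GridSearchCV and several estimators, and the search itself seems to work fine. However, when I display gridsearch_model.cv_results_ as a DataFrame, the params column makes every candidate look like the same classifier (DecisionTreeClassifier). When I write the columns to a CSV file, it does show the different classifiers, including RandomForestClassifier. The code is below: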

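import pandas as pd
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline
from sklearn.base import BaseEstimator
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (average_precision_score, f1_score, make_scorer,
                             precision_score, recall_score)
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier

# Placeholder for the final step; GridSearchCV replaces it through the
# 'classification' key in params_grid, so it is never fitted itself.
class DummyEstimator(BaseEstimator):
    def fit(self, X, y=None):
        return self

    def score(self, X, y=None):
        raise NotImplementedError

pipe = Pipeline(steps=[
    ('scale', StandardScaler()),
    ('reduce_dim', PCA(n_components='mle')),
    ('sampling', SMOTE()),
    ('classification', DummyEstimator())])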

# One dict per estimator: the estimator itself plus its estimator-related parameters
params_grid = [
    {
        'classification': [RandomForestClassifier()],
        'classification__max_depth': [2, 3],
    },
    {
        'classification': [DecisionTreeClassifier()],
        'classification__max_depth': [4, 5],
    },
]

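scorers = {
    'precision_score': make_scorer(precision_score),
    'recall_score': make_scorer(recall_score),
    'f1_score': make_scorer(f1_score),
    # make_scorer feeds hard predict() output here; the built-in
    # 'average_precision' scorer uses decision_function/predict_proba instead
    'auprc': make_scorer(average_precision_score),
}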

gridsearch_model = GridSearchCV(pipe, params_grid, cv=2, scoring=scorers, refit='recall_score')
gridsearch_model.fit(X_train, y_train)
results = pd.DataFrame(gridsearch_model.cv_results_)
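To check which classifier each row of cv_results_ should belong to, the candidates can be listed with ParameterGrid, which is what GridSearchCV uses to expand a list-of-dicts grid (row i of cv_results_ is candidate i). This is a minimal sketch, assuming a recent scikit-learn; it reuses the grid from the question but leaves out the pipeline and the data:

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import ParameterGrid
from sklearn.tree import DecisionTreeClassifier

# Same grid shape as in the question
params_grid = [
    {'classification': [RandomForestClassifier()],
     'classification__max_depth': [2, 3]},
    {'classification': [DecisionTreeClassifier()],
     'classification__max_depth': [4, 5]},
]

# Candidates in evaluation order; index i corresponds to row i of cv_results_
candidates = list(ParameterGrid(params_grid))
for i, cand in enumerate(candidates):
    print(i, type(cand['classification']).__name__, cand['classification__max_depth'])

# Both forest candidates hold the very same RandomForestClassifier object from the grid
print(candidates[0]['classification'] is candidates[1]['classification'])
```

With this grid, rows 0 and 1 are the RandomForestClassifier candidates (max_depth 2 and 3) and rows 2 and 3 are the DecisionTreeClassifier candidates (max_depth 4 and 5). The last line prints True, since ParameterGrid does not copy the estimator objects listed in the grid.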

This gives the following output:

   params                                             mean_test_recall_score  std_test_recall_score  mean_test_auprc  std_test_auprc  mean_test_precision_score  std_test_precision_score  mean_test_f1_score  std_test_f1_score
2  {'classification': DecisionTreeClassifier(clas...  0.963624                0.025077               0.299055         0.052214        0.311904                   0.062261                  0.466833            0.068199
1  {'classification': (DecisionTreeClassifier(cla...  0.958037                0.030663               0.342001         0.149949        0.362291                   0.168061                  0.499737            0.175061
3  {'classification': DecisionTreeClassifier(clas...  0.943945                0.022157               0.403004         0.130787        0.430331                   0.148617                  0.573716            0.137491
0  {'classification': (DecisionTreeClassifier(cla...

But when I access a row directly with results['params'][0], it does show RandomForestClassifier(....).
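One plausible reading of the output (an assumption, not confirmed in this thread): rows 0 and 1 start with an opening parenthesis, which is how pandas renders sequence-like objects. A fitted RandomForestClassifier is sequence-like, because it supports len() and iteration over its fitted trees (estimators_), and those trees are DecisionTreeClassifier instances. If the forest object stored in the grid ends up fitted (some older scikit-learn releases reportedly did not clone estimator-valued grid parameters before fitting), the params column would display it as a tuple of trees, while results['params'][0] and a CSV export show the object's own repr. The sketch below uses synthetic data from make_classification and a hypothetical two-row Series standing in for results['params']; labelling rows with type(...).__name__ does not depend on how pandas renders the objects:

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier

# Synthetic data, only needed to obtain a fitted forest
X, y = make_classification(n_samples=100, random_state=0)
rf = RandomForestClassifier(n_estimators=3, max_depth=2, random_state=0).fit(X, y)

# A fitted forest supports len() and iterates over its trees (estimators_)
print(len(rf), [type(tree).__name__ for tree in rf])

# Hypothetical stand-in for results['params']: one fitted forest, one unfitted tree
params = pd.Series([{'classification': rf},
                    {'classification': DecisionTreeClassifier(max_depth=4)}])
print(params)  # the forest entry may be rendered as a tuple of DecisionTreeClassifier objects

# Class-name labels are unaffected by how the estimator objects are displayed
labels = params.apply(lambda p: type(p['classification']).__name__)
print(labels.tolist())
```

On the real results, the same idea would be results['params'].apply(lambda p: type(p['classification']).__name__), added as an extra column next to the scores.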

0 answers:

No answers yet.