Skipping a pipeline step in GridSearchCV

Time: 2020-07-17 10:56:07

Tags: python scikit-learn pipeline grid-search

I created the following pipelines:

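# scikit-learn imports used by the snippets in this question
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import MinMaxScaler, OrdinalEncoder
from sklearn.ensemble import RandomForestClassifier

# adhoc_transf, missing_val_imput and feature_select are my own modules
pipeline_numeric_feat= Pipeline([('mispelling',adhoc_transf.misspellingTransformer()),
                                 ('features_cast',adhoc_transf.Numeric_Cast_Column()),
                                 ('data_missing',missing_val_imput.Numeric_Imputer(strategy='median')),
                                 ('features_select',feature_select.Feature_Selector(strategy='wrapper_RFECV')),
                                 ('scaler', MinMaxScaler())
                        ])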

pipeline_category_feat= Pipeline([('mispelling',adhoc_transf.misspellingTransformer()),
                                 ('features_cast',adhoc_transf.Category_Cast_Column()),
                                 ('data_missing',missing_val_imput.Category_Imputer(strategy='most_frequent')),
                                 ('cat_feat_engineering',adhoc_transf.CastDown()),
                                 ('encoding', OrdinalEncoder()),
                                 ('features_select',feature_select.Feature_Selector(strategy='wrapper_RFECV'))
                        ])

dataprep_pipe=ColumnTransformer([('numeric_pipe',pipeline_numeric_feat,numerical_features),
                                 ('category_pipe',pipeline_category_feat, category_features)
                                ])

full_pipeline=Pipeline([('data_prep',dataprep_pipe),
                        ('model',RandomForestClassifier())])

I want to run a GridSearchCV that can optionally skip the "features_select" step in pipeline_numeric_feat and pipeline_category_feat. To do that, my param_grid is:

param_grid_v4={'model': [SGDClassifier(),LogisticRegression(),LinearSVC(),SVC(),DecisionTreeClassifier(),RandomForestClassifier()],
            'data_prep__numeric_pipe__features_select':['passthrough',feature_select.Feature_Selector()],
            'data_prep__categroy_pipe__features_select':['passthrough',feature_select.Feature_Selector()]
    }

But at run time I get the following error: Invalid parameter categroy_pipe for estimator.

clf_v4=GridSearchCV(full_pipeline,param_grid_v4,cv=5,n_jobs=-1)
clf_v4.fit(X_train,y_train)
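
The error message names `categroy_pipe`, whereas the ColumnTransformer above registers that branch as `category_pipe`. scikit-learn raises "Invalid parameter ..." when a `param_grid` key does not match a name returned by `get_params()`, so the misspelled key looks like a more likely cause than `Feature_Selector()` itself. A minimal sketch for checking key names, assuming `SelectKBest` as a stand-in for the custom `Feature_Selector` and hypothetical column indices:

```python
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import SelectKBest
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler, OrdinalEncoder

# Stand-ins (assumptions): SelectKBest replaces the custom Feature_Selector,
# and the column indices [0, 1] / [2, 3] are hypothetical.
numeric_pipe = Pipeline([
    ("features_select", SelectKBest(k=1)),
    ("scaler", MinMaxScaler()),
])
category_pipe = Pipeline([
    ("encoding", OrdinalEncoder()),
    ("features_select", SelectKBest(k=1)),
])
data_prep = ColumnTransformer([
    ("numeric_pipe", numeric_pipe, [0, 1]),
    ("category_pipe", category_pipe, [2, 3]),
])
full_pipeline = Pipeline([
    ("data_prep", data_prep),
    ("model", RandomForestClassifier()),
])

# Every key of param_grid must be one of these names.
valid_names = set(full_pipeline.get_params().keys())

misspelled = "data_prep__categroy_pipe__features_select"
corrected = "data_prep__category_pipe__features_select"

print(misspelled in valid_names)  # False
print(corrected in valid_names)   # True
```

Printing `sorted(full_pipeline.get_params().keys())` on the real pipeline should show the exact spelling expected for each nested step.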

I know the problem is in the following line of param_grid:

'data_prep__categroy_pipe__features_select':['passthrough',feature_select.Feature_Selector()],

in particular this piece of code: feature_select.Feature_Selector()

So how can I define the 'data_prep__categroy_pipe__features_select' step in param_grid so that it is sometimes skipped and sometimes not?
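
One way this could look, sketched under assumptions: listing `'passthrough'` next to a selector instance lets the grid search try the step both skipped and active, provided the key is spelled `category_pipe` to match the ColumnTransformer. In the sketch below, `SelectKBest` stands in for the custom `Feature_Selector`, the data is synthetic, the column indices are hypothetical, and only two models are searched to keep it short.

```python
import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.feature_selection import SelectKBest
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler, OrdinalEncoder
from sklearn.tree import DecisionTreeClassifier

# Synthetic data (assumption): columns 0-1 are numeric, columns 2-3 hold
# category codes that OrdinalEncoder treats as categories.
rng = np.random.default_rng(0)
X = np.column_stack([
    rng.normal(size=60),
    rng.normal(size=60),
    np.tile([0, 1], 30),
    np.tile([0, 1, 2], 20),
])
y = (X[:, 0] > 0).astype(int)

# SelectKBest is a stand-in for the custom Feature_Selector.
numeric_pipe = Pipeline([
    ("features_select", SelectKBest(k=1)),
    ("scaler", MinMaxScaler()),
])
category_pipe = Pipeline([
    ("encoding", OrdinalEncoder()),
    ("features_select", SelectKBest(k=1)),
])
data_prep = ColumnTransformer([
    ("numeric_pipe", numeric_pipe, [0, 1]),
    ("category_pipe", category_pipe, [2, 3]),
])
full_pipeline = Pipeline([
    ("data_prep", data_prep),
    ("model", LogisticRegression(max_iter=1000)),
])

# 'passthrough' skips the step; the estimator instance keeps it active.
param_grid = {
    "model": [LogisticRegression(max_iter=1000),
              DecisionTreeClassifier(random_state=0)],
    "data_prep__numeric_pipe__features_select": ["passthrough", SelectKBest(k=1)],
    "data_prep__category_pipe__features_select": ["passthrough", SelectKBest(k=1)],
}

search = GridSearchCV(full_pipeline, param_grid, cv=3)
search.fit(X, y)

# 2 models x 2 numeric options x 2 category options
print(len(search.cv_results_["params"]))  # 8
print(search.best_params_)
```

`search.cv_results_["params"]` lists each candidate, so it shows which combinations ran with the selection step replaced by `'passthrough'`.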

0 answers:

No answers yet.