在带有嵌套估算器的管道中使用GridSearchCV

时间:2020-09-26 19:13:16

标签: python scikit-learn pipeline random-forest gridsearchcv

我尝试使用管道构建这样的模型:我想使用随机的Forst分类器预测多个输出。由于管道只允许最后一步成为分类器,因此我嵌套了管道。在没有GridSearch的情况下,该方法可以正常工作。

pipeline = Pipeline([
('vect', CountVectorizer()),
('tfidf', TfidfTransformer()),
('clf', MultiOutputClassifier(RandomForestClassifier(), n_jobs=-1)),
])

现在,我尝试将多个参数传递给我的RF分类器,但是由于它是嵌套的,因此它将传递给MultiOutputClassifier,至少我认为是这样。

param_grid = { 
    'clf__n_estimators': [200, 500],
    'clf__max_features': ['auto', 'sqrt', 'log2'],
    'clf__max_depth' : [4,5,6,7,8],
    'clf__criterion' :['gini', 'entropy']
}

cv = GridSearchCV(pipeline, param_grid=param_grid)

这将导致错误:ValueError:估算器的无效参数标准

是否可以将参数传递给我的RandomForestClassifier还是可以通过管道传递多个分类器?

1 个答案:

答案 0 :(得分:1)

尝试一下:

pipeline = Pipeline([
('vect', CountVectorizer()),
('tfidf', TfidfTransformer()),
('clf', MultiOutputClassifier(RandomForestClassifier(), n_jobs=-1)),
])

param_grid = { 
    'clf__estimator__n_estimators': [200, 500],
    'clf__estimator__max_features': ['auto', 'sqrt', 'log2'],
    'clf__estimator__max_depth' : [4,5,6,7,8],
    'clf__estimator__criterion' :['gini', 'entropy']
}

cv = GridSearchCV(pipeline, param_grid=param_grid, n_jobs=2)

通常,您可以通过以下方式访问可调参数:

cv.get_params().keys()
dict_keys(['cv', 'error_score', 'estimator__memory', 'estimator__steps', 'estimator__verbose', 'estimator__vect', 'estimator__tfidf', 'estimator__clf', 'estimator__vect__analyzer', 'estimator__vect__binary', 'estimator__vect__decode_error', 'estimator__vect__dtype', 'estimator__vect__encoding', 'estimator__vect__input', 'estimator__vect__lowercase', 'estimator__vect__max_df', 'estimator__vect__max_features', 'estimator__vect__min_df', 'estimator__vect__ngram_range', 'estimator__vect__preprocessor', 'estimator__vect__stop_words', 'estimator__vect__strip_accents', 'estimator__vect__token_pattern', 'estimator__vect__tokenizer', 'estimator__vect__vocabulary', 'estimator__tfidf__norm', 'estimator__tfidf__smooth_idf', 'estimator__tfidf__sublinear_tf', 'estimator__tfidf__use_idf', 'estimator__clf__estimator__bootstrap', 'estimator__clf__estimator__ccp_alpha', 'estimator__clf__estimator__class_weight', 'estimator__clf__estimator__criterion', 'estimator__clf__estimator__max_depth', 'estimator__clf__estimator__max_features', 'estimator__clf__estimator__max_leaf_nodes', 'estimator__clf__estimator__max_samples', 'estimator__clf__estimator__min_impurity_decrease', 'estimator__clf__estimator__min_impurity_split', 'estimator__clf__estimator__min_samples_leaf', 'estimator__clf__estimator__min_samples_split', 'estimator__clf__estimator__min_weight_fraction_leaf', 'estimator__clf__estimator__n_estimators', 'estimator__clf__estimator__n_jobs', 'estimator__clf__estimator__oob_score', 'estimator__clf__estimator__random_state', 'estimator__clf__estimator__verbose', 'estimator__clf__estimator__warm_start', 'estimator__clf__estimator', 'estimator__clf__n_jobs', 'estimator', 'iid', 'n_jobs', 'param_grid', 'pre_dispatch', 'refit', 'return_train_score', 'scoring', 'verbose'])