如何在scikit-learn中通过GridSearchCV调整嵌套管道的参数?

时间:2013-05-08 09:24:19

标签: scikit-learn

是否可以在scikit-learn中调整嵌套管道的参数? E.g:

svm = Pipeline([
    ('chi2', SelectKBest(chi2)),
    ('cls', LinearSVC(class_weight='auto'))
])

classifier = Pipeline([
    ('vectorizer', TfIdfVectorizer()),
    ('ova_svm', OneVsRestClassifier(svm))
})

parameters = ?

GridSearchCV(classifier, parameters)

如果无法直接执行此操作,可能是一种解决方法?

2 个答案:

答案 0 :(得分:14)

scikit-learn有一个双重下划线表示法,as exemplified here。它以递归方式工作并延伸到OneVsRestClassifier,但需要注意的是必须将基础估算工具明确地称为__estimator

parameters = {'ova_svm__estimator__cls__C': [1, 10, 100],
              'ova_svm__estimator__chi2_k': [200, 500, 1000]}

答案 1 :(得分:11)

对于您创建的估算工具,您可以按如下方式获取带有标签的参数列表。

import pprint as pp

pp.pprint(sorted(classifier.get_params().keys()))
  

[' ova_svm',' ova_svm__estimator',' ova_svm__estimator__chi2',   ' ova_svm__estimator__chi2__k&#39 ;,   ' ova_svm__estimator__chi2__score_func',' ova_svm__estimator__cls',   ' ova_svm__estimator__cls__C&#39 ;,   ' ova_svm__estimator__cls__class_weight&#39 ;,   ' ova_svm__estimator__cls__dual&#39 ;,   ' ova_svm__estimator__cls__fit_intercept&#39 ;,   ' ova_svm__estimator__cls__intercept_scaling&#39 ;,   ' ova_svm__estimator__cls__loss',' ova_svm__estimator__cls__max_iter',   ' ova_svm__estimator__cls__multi_class&#39 ;,   ' ova_svm__estimator__cls__penalty&#39 ;,   ' ova_svm__estimator__cls__random_state&#39 ;,   ' ova_svm__estimator__cls__tol',' ova_svm__estimator__cls__verbose',   ' ova_svm__estimator__steps',' ova_svm__n_jobs',' steps',   ' vectorizer',' vectorizer__analyzer',' vectorizer__binary',   ' vectorizer__decode_error',' vectorizer__dtype',   ' vectorizer__encoding',' vectorizer__input',   ' vectorizer__lowercase',' vectorizer__max_df',   ' vectorizer__max_features',' vectorizer__min_df',   ' vectorizer__ngram_range',' vectorizer__norm',   ' vectorizer__preprocessor',' vectorizer__smooth_idf',   ' vectorizer__stop_words',' vectorizer__strip_accents',   ' vectorizer__sublinear_tf',' vectorizer__token_pattern',   ' vectorizer__tokenizer',' vectorizer__use_idf',   ' vectorizer__vocabulary']

然后,您可以从此列表中设置要在其上执行GridSearchCV的参数。