Scikit-learn多输出分类器使用:GridSearchCV,Pipeline,OneVsRestClassifier,SGDClassifier

时间:2016-11-01 00:43:13

标签: python scikit-learn multilabel-classification

我正在尝试使用GridSearchCV和Pipeline构建一个多输出模型。 Pipeline给我带来了麻烦,因为标准分类器示例没有包含分类器的OneVsRestClassifier()。我使用scikit-learn 0.18和python 3.5

## Pipeline: Train and Predict
## SGD: support vector machine (SVM) with gradient descent
from sklearn.multiclass import OneVsRestClassifier
from sklearn.pipeline import Pipeline
from sklearn.linear_model import SGDClassifier

clf = Pipeline([
               ('vect', CountVectorizer(ngram_range=(1,3), max_df=0.50 ) ),
               ('tfidf', TfidfTransformer() ),
               ('clf', SGDClassifier(loss='modified_huber', penalty='elasticnet',
                                          alpha=1e-4, n_iter=5, random_state=42,
                                          shuffle=True, n_jobs=-1) ),
                ])

ovr_clf = OneVsRestClassifier(clf ) 

from sklearn.model_selection import GridSearchCV
parameters = {'vect__ngram_range': [(1,1), (1,3)],
              'tfidf__norm': ('l1', 'l2', None),
              'estimator__loss': ('modified_huber', 'hinge',),
             }

gs_clf = GridSearchCV(estimator=pipeline, param_grid=parameters, 
                      scoring='f1_weighted', n_jobs=-1, verbose=1)
gs_clf = gs_clf.fit(X_train, y_train)

但这会产生错误: ....

  

ValueError:估算器的参数估算器无效   管道(步骤= [(' vect',CountVectorizer(analyzer =' word',   binary = False,decode_error =' strict',           dtype =,encoding =' utf-8',input =' content',           lowercase = True,max_df = 0.5,max_features = None,min_df = 1,           ngram_range =(1,3),预处理器=无,stop_words =无,           strip ... er_t = 0.5,random_state = 42,shuffle = True,          verbose = 0,warm_start = False),             n_jobs = -1))])。使用estimator.get_params().keys()检查可用参数列表。

那么使用param_grid和Pipeline通过OneVsRestClassifier将参数传递给clf的正确方法是什么?我是否需要将矢量化器和tdidf与管道中的分类器分开?

1 个答案:

答案 0 :(得分:12)

将OneVsRestClassifier()作为管道本身的一步,并将SGDClassifier作为OneVsRestClassifier的估算器。 你可以这样。

pipeline = Pipeline([
               ('vect', CountVectorizer(ngram_range=(1,3), max_df=0.50 ) ),
               ('tfidf', TfidfTransformer() ),
               ('clf', OneVsRestClassifier(SGDClassifier(loss='modified_huber', penalty='elasticnet',
                                          alpha=1e-4, n_iter=5, random_state=42,
                                          shuffle=True, n_jobs=-1) )),
                ])

其余代码可以保持不变。 OneVsRestClassifier充当其他估算器的包装器。