Using a Pipeline with GridSearchCV

Time: 2020-06-30 17:52:10

Tags: scikit-learn svm pipeline grid-search

Suppose I have this Pipeline object:

from sklearn.pipeline import Pipeline
from sklearn.svm import SVC
# my_transform is my own transformer class (not shown here)
pipe = Pipeline([
    ('my_transform', my_transform()),
    ('estimator', SVC())
])

To pass hyperparameters to the support vector classifier (SVC), I can do the following:

pipe_parameters = {
    'estimator__gamma': (0.1, 1),
    'estimator__kernel': ['rbf']
}

Then I can use GridSearchCV:

from sklearn.model_selection import GridSearchCV
grid = GridSearchCV(pipe, pipe_parameters)
grid.fit(X_train, y_train)

We know that the linear kernel does not use gamma as a hyperparameter. So how can I include the linear kernel in this grid search?

In a plain GridSearchCV (without a pipeline) I can do this by passing one parameter dictionary per kernel. So I need a working version of that approach for the pipeline: gamma should be searched only together with the rbf kernel, and the linear kernel should be searched without it.
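The `estimator__` prefix in the parameter keys above is how scikit-learn addresses a parameter of a named pipeline step: step name, two underscores, parameter name. A minimal sketch of that naming rule follows. `StandardScaler` is only a stand-in for the question's `my_transform`, which is not shown, so treat that step as an assumption.

```python
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# StandardScaler is a placeholder for the asker's own my_transform (assumption).
pipe = Pipeline([
    ('my_transform', StandardScaler()),
    ('estimator', SVC()),
])

# Every key that a grid may use is listed by get_params();
# keys for the SVC step all start with 'estimator__'.
svc_keys = sorted(k for k in pipe.get_params() if k.startswith('estimator__'))
print(svc_keys)

# set_params() uses the same '<step>__<parameter>' names that GridSearchCV
# applies to each candidate before fitting.
pipe.set_params(estimator__kernel='linear', estimator__C=0.1)
print(pipe.named_steps['estimator'].kernel)
```

Printing the keys first is a quick way to catch a misspelled step or parameter name before starting a long search.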

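1 answer:

Answer 0 (score: 2)

You are almost there. Just as you created multiple dictionaries for the SVC model, create a list of dictionaries for the pipeline.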
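To see why a list of dictionaries solves the problem, it may help to expand one by hand. The sketch below is plain Python, not scikit-learn's implementation: the helper `expand_grid` is a hypothetical name introduced here, and it mimics the behaviour of taking the Cartesian product inside each dictionary separately. scikit-learn offers `sklearn.model_selection.ParameterGrid` for the same purpose. The grid is the one used in this answer's example.

```python
from itertools import product


def expand_grid(param_grid):
    """Yield one candidate dict per combination, one sub-grid at a time."""
    for sub_grid in param_grid:
        keys = sorted(sub_grid)
        # Cartesian product of the value lists within this dictionary only.
        for values in product(*(sub_grid[k] for k in keys)):
            yield dict(zip(keys, values))


pipe_parameters = [
    {'bag_of_words__max_features': (None, 1500),
     'estimator__C': [0.1],
     'estimator__gamma': [0.0001, 1],
     'estimator__kernel': ['rbf']},
    {'bag_of_words__max_features': (None, 1500),
     'estimator__C': [0.1, 1],
     'estimator__kernel': ['linear']},
]

candidates = list(expand_grid(pipe_parameters))
for candidate in candidates:
    print(candidate)
print(len(candidates))  # 4 rbf candidates + 4 linear candidates = 8
```

Because each dictionary is expanded independently, gamma is never combined with the linear kernel, which is exactly the behaviour asked for in the question.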

Try the following example:

from sklearn.datasets import fetch_20newsgroups
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.svm import SVC

categories = [
    'alt.atheism',
    'talk.religion.misc',
    'comp.graphics',
    'sci.space',
]
remove = ('headers', 'footers', 'quotes')

data_train = fetch_20newsgroups(subset='train', categories=categories,
                                shuffle=True, random_state=42,
                                remove=remove)

pipe = Pipeline([
    ('bag_of_words', CountVectorizer()),
    ('estimator', SVC())])
pipe_parameters = [
    {'bag_of_words__max_features': (None, 1500),
     'estimator__C': [0.1],
     'estimator__gamma': [0.0001, 1],
     'estimator__kernel': ['rbf']},
    {'bag_of_words__max_features': (None, 1500),
     'estimator__C': [0.1, 1],
     'estimator__kernel': ['linear']}
]
from sklearn.model_selection import GridSearchCV
grid = GridSearchCV(pipe, pipe_parameters, cv=2)
grid.fit(data_train.data, data_train.target)

grid.best_params_
# {'bag_of_words__max_features': None,
#  'estimator__C': 0.1,
#  'estimator__kernel': 'linear'}
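For a quicker check that needs no dataset download, the same idea can be tried on the bundled iris data. This is a sketch under assumptions: `StandardScaler` and the iris dataset are substitutions made here for illustration and are not part of the original answer. It confirms that no linear candidate carries a gamma value.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)

# 'scale' replaces the text vectorizer because iris is already numeric (assumption).
pipe = Pipeline([
    ('scale', StandardScaler()),
    ('estimator', SVC()),
])

pipe_parameters = [
    {'estimator__C': [0.1, 1],
     'estimator__gamma': [0.0001, 1],
     'estimator__kernel': ['rbf']},
    {'estimator__C': [0.1, 1],
     'estimator__kernel': ['linear']},
]

grid = GridSearchCV(pipe, pipe_parameters, cv=2)
grid.fit(X, y)

# cv_results_['params'] lists every candidate that was actually evaluated.
candidates = grid.cv_results_['params']
linear_candidates = [p for p in candidates
                     if p['estimator__kernel'] == 'linear']
print(len(candidates))
print(all('estimator__gamma' not in p for p in linear_candidates))
print(grid.best_params_)
```

After fitting, `grid.best_estimator_` holds the whole pipeline refit with the winning parameters, so it can be used directly for prediction.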