GridSearchCV用于python中的多类SVM

时间:2018-06-08 02:35:06

标签: python scikit-learn svm grid-search

我正在尝试学习如何为分类器找到最佳参数。所以,我使用 GridSearchCV 来解决多类分类问题。在Does not GridSearchCV support multi-class?上生成了一个虚拟代码。我正在使用n_classes = 3的代码。

import numpy as np
from sklearn.datasets import make_classification
from sklearn.preprocessing import StandardScaler,label_binarize
from sklearn.svm import SVC
from sklearn.pipeline import make_pipeline
from sklearn.grid_search import GridSearchCV
from sklearn.metrics import accuracy_score, recall_score, f1_score, roc_auc_score, make_scorer

X, y = make_classification(n_samples=3000, n_features=10, weights=[0.1, 0.9, 0.3],n_classes=3, n_clusters_per_class=1,n_informative=2)

pipe = make_pipeline(StandardScaler(), SVC(kernel='rbf', class_weight='auto'))

param_space = dict(svc__C=np.logspace(-5,0,5), svc__gamma=np.logspace(-2, 2, 10))

f1_score
my_scorer = make_scorer(f1_score, greater_is_better=True)

gscv = GridSearchCV(pipe, param_space, scoring=my_scorer)

我正在尝试按照Scikit-learn GridSearch giving "ValueError: multiclass format is not supported" error的建议执行单热编码。此外,有时会有像Toxic Comment Classification dataset on Kaggle这样的数据集,它会为您提供二值化标签。

y = label_binarize(y, classes=[0, 1, 2])
for i in classes:    
gscv.fit(X, y[i])

print gscv.best_params_

我得到了:

ValueError: bad input shape (2000L, 3L)

我不确定为什么会收到此错误。我的目标是找到多类分类问题的最佳参数。

1 个答案:

答案 0 :(得分:1)

代码的两个部分存在两个问题。

1)当你没有对标签进行单热编码时,让我们从第一部分开始。你看,SVC支持多类案件就好了。但是f1_score与(内部)GridSearchCV结合使用时不会。

默认情况下,

f1_score会在二进制分类的情况下返回正标签的分数,因此会在您的情况下抛出错误。

OR 它也可以返回一个分数数组(每个类一个),但GridSearchCV只接受一个值作为分数,因为它需要找到最佳分数和超级分数的最佳组合参数。因此,您需要通过f1_score中的平均方法从数组中获取单个值。

根据f1_score documentation,允许采用以下平均方法:

  

average:string,[None,'binary'(默认),'micro','macro',   'samples','加权']

所以改变你的make_scorer:

my_scorer = make_scorer(f1_score, greater_is_better=True, average='micro')

根据您的需要更改上面的'average'参数。

2)现在进入第二部分:当您对标签进行单热编码时,y的形状变为2-d,但SVC仅支持1-d数组{{ 1}}如文档中所述:

y

但即使您对标签进行编码并使用支持2-d标签的分类器,也必须解决第一个错误。因此,我建议您不要对标签进行单热编码,只需更改fit(X, y, sample_weight=None)[source] X : {array-like, sparse matrix}, shape (n_samples, n_features) y : array-like, shape (n_samples,)