课程数量

时间:2018-01-29 14:40:50

标签: python validation scikit-learn svm

我正在根据本教程http://scikit-learn.org/stable/modules/learning_curve.html

将以下代码应用于我自己的数据
from sklearn.model_selection import learning_curve
from sklearn.svm import SVC
train_sizes, train_scores, valid_scores = learning_curve(SVC(kernel='linear'), X, y, train_sizes=[50, 80, 110], cv=5)

但是,我收到以下错误ValueError: The number of classes has to be greater than one; got 1

这是我的X和y:

X.shape (2163, 8891)

y.shape (2163,)

type(X) <class 'numpy.ndarray'>

type(y) <class 'numpy.ndarray'>

使用print(set(y))会产生两个类{'R', 'N'}

有关如何导致此错误的任何想法?

1 个答案:

答案 0 :(得分:2)

可能由于cv = 5而发生。由于您使用的是整数,因此将使用简单的K-Fold迭代器,这可能会以这样的方式分割数据:在给定的训练折叠中,只存在单个类

请尝试使用StratifiedKFold

from sklearn.model_selection import StratifiedKFold
skf = StratifiedKFold(n_splits=5)
train_sizes, train_scores, valid_scores = learning_curve(SVC(kernel='linear'), 
                                                     X, y, 
                                                     train_sizes=[50, 80, 110], 
                                                     cv=skf)