我正在根据本教程http://scikit-learn.org/stable/modules/learning_curve.html
将以下代码应用于我自己的数据from sklearn.model_selection import learning_curve
from sklearn.svm import SVC
train_sizes, train_scores, valid_scores = learning_curve(SVC(kernel='linear'), X, y, train_sizes=[50, 80, 110], cv=5)
但是,我收到以下错误ValueError: The number of classes has to be greater than one; got 1
这是我的X和y:
X.shape
(2163, 8891)
y.shape
(2163,)
type(X)
<class 'numpy.ndarray'>
type(y)
<class 'numpy.ndarray'>
使用print(set(y))
会产生两个类{'R', 'N'}
有关如何导致此错误的任何想法?
答案 0 :(得分:2)
可能由于cv = 5而发生。由于您使用的是整数,因此将使用简单的K-Fold迭代器,这可能会以这样的方式分割数据:在给定的训练折叠中,只存在单个类
请尝试使用StratifiedKFold。
from sklearn.model_selection import StratifiedKFold
skf = StratifiedKFold(n_splits=5)
train_sizes, train_scores, valid_scores = learning_curve(SVC(kernel='linear'),
X, y,
train_sizes=[50, 80, 110],
cv=skf)