这个错误对StratifiedShuffleSplit意味着什么?

时间:2018-03-26 02:35:50

标签: python pandas scikit-learn data-science sklearn-pandas

我对数据科学一般都很陌生,希望有人可以解释为什么这不起作用:

我正在使用以下网址中的广告数据集:“http://www-bcf.usc.edu/~gareth/ISL/Advertising.csv”,其中包含3个功能列(“电视”,“广播”,“报纸”)和1个标签列(“销售”)。我的完整数据集名为data

接下来,我尝试使用sklearn的StratifiedShuffleSplit函数将数据划分为训练和测试集。

from sklearn.model_selection import StratifiedShuffleSplit

split = StratifiedShuffleSplit(n_splits=1, random_state=0) # can use test_size=0.8
for train_index, test_index in split.split(data.drop("sales", axis=1), data["sales"]): # Generate indices to split data into training and test set.
    strat_train_set = data.loc[train_index]
    strat_test_set = data.loc[test_index]

我得到了这个ValueError: The least populated class in y has only 1 member, which is too few. The minimum number of groups for any class cannot be less than 2.

在另一个包含14个要素列和1个标签列的数据集上使用相同的代码可以适当地分隔数据。为什么不在这里工作?感谢。

1 个答案:

答案 0 :(得分:0)

我认为问题是你的data_y是2D矩阵。

但正如我在sklearn.model_selection.StratifiedShuffleSplit doc中看到的那样,它应该是1D向量。尝试将每行data_y编码为整数(它将被解释为一个类),并在使用split之后。

或者你的y可能是一个回归变量(连续数值数据)。(Vivek的链接)