参数“stratify”来自方法“train_test_split”(scikit Learn)

时间:2016-01-17 19:05:34

标签: split scikit-learn training-data test-data

我正在尝试使用软件包scikit Learn中的train_test_split,但我遇到了参数stratify的问题。以下是代码:

from sklearn import cross_validation, datasets 

X = iris.data[:,:2]
y = iris.target

cross_validation.train_test_split(X,y,stratify=y)

但是,我一直遇到以下问题:

raise TypeError("Invalid parameters passed: %s" % str(options))
TypeError: Invalid parameters passed: {'stratify': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])}

有人知道发生了什么吗?以下是功能文档。

  

[...]

     

分层:数组类似或无(默认为无)

     

如果不是None,则数据以分层方式分割,使用此作为标签数组。

     

版本0.17中的新内容:分层拆分

     

[...]

5 个答案:

答案 0 :(得分:162)

stratify参数进行拆分,以便生成的样本中的值的比例与提供给参数stratify的值的比例相同。

例如,如果变量y是值为01的二进制分类变量,并且有25%的零和75%的零,stratify=y将确保你的随机分割有25%的0和75%的1

答案 1 :(得分:38)

Scikit-Learn只是告诉你它并没有认识到这个论点"分层"而不是你错误地使用它。这是因为参数是在0.17版本中添加的,如您引用的文档中所示。

所以你只需要更新Scikit-Learn。

答案 2 :(得分:31)

通过Google来到这里的未来自我:

train_test_split现在位于model_selection,因此:

from sklearn.model_selection import train_test_split

# given:
# features: xs
# ground truth: ys

x_train, x_test, y_train, y_test = train_test_split(xs, ys,
                                                    test_size=0.33,
                                                    random_state=0,
                                                    stratify=ys)

是使用它的方式。设置random_state对于再现性是可取的。

答案 3 :(得分:7)

在此上下文中,分层意味着train_test_split方法返回具有与输入数据集相同比例的类标签的训练和测试子集。

答案 4 :(得分:2)

尝试运行此代码,它只是工作":

from sklearn import cross_validation, datasets 

iris = datasets.load_iris()

X = iris.data[:,:2]
y = iris.target

x_train, x_test, y_train, y_test = cross_validation.train_test_split(X,y,train_size=.8, stratify=y)

y_test

array([0, 0, 0, 0, 2, 2, 1, 0, 1, 2, 2, 0, 0, 1, 0, 1, 1, 2, 1, 2, 0, 2, 2,
       1, 2, 1, 1, 0, 2, 1])