Scikit学习是否有根据多个变量进行分层拆分的方法?

时间:2019-03-29 10:50:37

标签: python machine-learning scikit-learn data-science training-data

我正在处理一个数据框,我注意到3个变量对于预测标签确实非常重要。因此,我想在测试和训练集中拆分数据框,但不进行随机拆分,而是基于这3个变量进行分层拆分(以在培训集中保持与原始数据框相同的分布)。已经创建函数StratifiedShuffleSplit来处理标签,因此如果我没有记错的话,我只能指定一个变量而不是三个变量。有人可以帮助我吗?谢谢

1 个答案:

答案 0 :(得分:1)

此交叉验证对象是StratifiedKFold和ShuffleSplit的合并,返回分层的随机褶皱。折叠是通过保留每个类别的样本百分比来完成的。

注意:像ShuffleSplit策略一样,分层随机拆分并不能保证所有折痕都会有所不同,尽管对于较大的数据集来说这仍然很有可能。

>>> StratifiedShuffleSplit(n_splits=5, random_state=0, ...)
>>> for train_index, test_index in sss.split(X, y):
...    print("TRAIN:", train_index, "TEST:", test_index)
...    X_train, X_test = X[train_index], X[test_index]
...    y_train, y_test = y[train_index], y[test_index]