使用sklearn中的train_test_split时,指定类中的最大项目数

时间:2018-09-04 20:02:04

标签: python scikit-learn scipy

我正在处理高度不平衡的数据集,并使用sklearn.model_selection中的train_test_split

我在这个数据集中有10000个项目,这些类的比率约为10/2/2/1,我正在寻找一种使train平衡的方法 并且我想在达到最大数量时停止向最大类中添加元素。

是否可以限制商品数量,我知道拆分后可以删除多余的商品,但是我想知道是否存在这样的选择吗?

1 个答案:

答案 0 :(得分:1)

在调用stratify函数时使用train_test_split参数。遵循documentation以获得更多信息。

对于30%的测试数据,您可以这样做

X_train,X_test, y_train, y_test = train_test_split(data, y_true, stratify=y_true, test_size=0.3)

data是您的总数据,而y_true是您的基本真值