数据集训练/测试拆分代码理解

时间:2020-02-25 23:42:11

标签: python scikit-learn

from sklearn.model_selection import StratifiedShuffleSplit

split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_index, test_index in split.split(housing, housing["income_cat"]):
    strat_train_set = housing.loc[train_index]
    strat_test_set = housing.loc[test_index]

我目前正在阅读动手学习ML,这则代码存在一些问题。我在python上没有太多经验,所以这可能是一个原因,但让我更清楚地说明一下。在这本书中,住房问题要求我们创建层次,以便数据集中每个层次都有足够的实例,我们使用我在此处未复制的代码来执行此操作,我显示的代码用于创建测试和训练集,使用特定的收入类别。第一和第二行代码很清楚,第三行是我迷路的地方。我们创建了测试0.2列0.8的拆分,但是从那时起到底发生了什么,for循环的作用是什么?

我浏览了几页以获取信息,但是并没有发现任何可以使情况清楚的信息,因此我非常感谢您的帮助。

预先感谢您的回答。

2 个答案:

答案 0 :(得分:1)

for循环只是获取用于拆分的索引,并调用原始数据的那些行以形成训练和测试集。

答案 1 :(得分:0)

如果您使用的是K折交叉验证,则可以使用StratifiedShuffleSplit更好,在交叉验证中,您可以用不同的方式划分训练和测试数据,然后计算K次迭代的结果平均值。

n_splits必须等于 K 值,而在您的情况下, K 是1,因此交叉验证毫无意义。我认为您最好使用sklearn.model_selection.train_test_split,这更有意义。