为什么要让For循环读取StratifiedShuffleSplit输出?

时间:2018-12-01 00:09:26

标签: python-3.x machine-learning scikit-learn

我正在n_splits = 1的python中使用StratifiedShuffleSplit,

我不明白为什么我仍然需要一个for循环来获取输出?为什么以下代码不起作用?

split=StratifiedShuffleSplit(n_splits=1,test_size=0.2,random_state=42) 
train_index, test_index = split.split(housing, housing["income_cat"])

这是原始代码

split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_index, test_index in split.split(housing, housing["income_cat"]):
    strat_train_set = housing.loc[train_index]
    strat_test_set = housing.loc[test_index]

1 个答案:

答案 0 :(得分:0)

正如@Vivek Kumar所评论的那样,代码第二行中的split.split()调用返回一个可迭代的对象(很可能是生成器,而不是列表或类似的东西)。您的非工作示例尝试使用返回值,就好像不是。

让我们看看循环消耗什么样的数据:

for train_index, test_index in ...:
    ...

for循环显然需要一个可迭代的。另外,train_index, test_index将可迭代项中的每个项“分解”为两个值,因此每个项必须是具有两个元素的可迭代项。通常,在这种情况下会使用元组。

因此,split.split()的结果可能类似于:

[
    (a1, b1),
    (a2, b2),
    ...
]

大概n_splits=1意味着将只有一对train_index, test_index-至少这是您似乎主张并需要验证的东西。在这种情况下,结果将是这样:

[
    (a1, b1),
]

所以只有一个项目,它本身就是具有两个项目的元组。现在,您尝试使用train_index, test_index = ...破坏单个项目的结构,这将失败:项目数不匹配。您需要先提取元组。

有两种获取元组的基本方法:

pair = split.split(...)[0]
pair, = split.split(...)

我会强烈建议第二个变体,因为当意外地有一个以上的项目时,它会失败;第一个变体只会默默地丢弃多余的物品。

然后,您可以破坏元组:

train_index, test_index = pair

或者两者都一步一步完成:

split = StratifiedShuffleSplit(n_splits=1,test_size=0.2,random_state=42) 
(train_index, test_index), = split.split(housing, housing["income_cat"])