我正在n_splits = 1的python中使用StratifiedShuffleSplit,
我不明白为什么我仍然需要一个for循环来获取输出?为什么以下代码不起作用?
split=StratifiedShuffleSplit(n_splits=1,test_size=0.2,random_state=42)
train_index, test_index = split.split(housing, housing["income_cat"])
这是原始代码
split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_index, test_index in split.split(housing, housing["income_cat"]):
strat_train_set = housing.loc[train_index]
strat_test_set = housing.loc[test_index]
答案 0 :(得分:0)
正如@Vivek Kumar所评论的那样,代码第二行中的split.split()
调用返回一个可迭代的对象(很可能是生成器,而不是列表或类似的东西)。您的非工作示例尝试使用返回值,就好像不是。
让我们看看循环消耗什么样的数据:
for train_index, test_index in ...:
...
for
循环显然需要一个可迭代的。另外,train_index, test_index
将可迭代项中的每个项“分解”为两个值,因此每个项必须是具有两个元素的可迭代项。通常,在这种情况下会使用元组。
因此,split.split()
的结果可能类似于:
[
(a1, b1),
(a2, b2),
...
]
大概n_splits=1
意味着将只有一对train_index, test_index
-至少这是您似乎主张并需要验证的东西。在这种情况下,结果将是这样:
[
(a1, b1),
]
所以只有一个项目,它本身就是具有两个项目的元组。现在,您尝试使用train_index, test_index = ...
破坏单个项目的结构,这将失败:项目数不匹配。您需要先提取元组。
有两种获取元组的基本方法:
pair = split.split(...)[0]
pair, = split.split(...)
我会强烈建议第二个变体,因为当意外地有一个以上的项目时,它会失败;第一个变体只会默默地丢弃多余的物品。
然后,您可以破坏元组:
train_index, test_index = pair
或者两者都一步一步完成:
split = StratifiedShuffleSplit(n_splits=1,test_size=0.2,random_state=42)
(train_index, test_index), = split.split(housing, housing["income_cat"])