在sklean.model_selection.StratifiedShuffleSplit中n_splits的目的是什么?

时间:2018-06-09 09:46:46

标签: python machine-learning scikit-learn

我最近开始使用sklearn并偶然发现了分层

ShuffleSplit功能。即使我理解它的概念以及它的意图,我也不太了解它需要运行的参数,例如 n_split 。根据sklearn的文档,写成

  

n_splits:int,default 10重新洗牌次数&分裂   迭代。

我最好的猜测是它告诉 StratifieShufflesplit 函数数据中的starta数。

2 个答案:

答案 0 :(得分:1)

n_splits是几乎每个交叉验证器的参数。通常,它确定您将创建多少个不同的验证(和培训)集。 如果您使用StratifiedShuffleSplit,则表示分层数 - 这些分数隐含在数据集中分类目标的基础相对频率中。

请参阅以下官方文档的引用(完整链接here

  

<强> StratifiedShuffleSplit

     

StratifiedShuffleSplit是ShuffleSplit的一种变体,它会返回   分层分裂,即通过保留分裂产生分裂   每个目标类的百分比,如完整集合。

答案 1 :(得分:-1)

对于StratifiedShuffleSplit,“ n_split”指定需要按照“ test_size”中提到的比例从每个层中采样数据的次数。

示例:这是一个包含4个层次的数据集的示例,每个层次均包含3个记录。 n_split = 3且test_size = 0.3给出:

测试数据在每个层次中都有一个记录,总体上约占数据集的30%,即12个记录中的4个。

---Example dataset---
df
​
ProductName Quantities
0   Mobile  20
1   Mobile  15
2   Mobile  12
3   PC  10
4   PC  8
5   PC  9
6   Tablet  5
7   Tablet  3
8   Tablet  4
9   RasPi   2
10  RasPi   1
11  RasPi   3


from sklearn.model_selection import StratifiedShuffleSplit
split = StratifiedShuffleSplit(n_splits=3,test_size=0.3,random_state=42)

for train_index,test_index in split.split(df,df['ProductName']):
    split_train = df.loc[train_index]
    split_test = df.loc[test_index]
    print("train:")
    print(split_train)
    print("test:")
    print(split_test)

---split_1---
train:
   ProductName  Quantities
5           PC           9
8       Tablet           4
0       Mobile          20
9        RasPi           2
10       RasPi           1
4           PC           8
1       Mobile          15
7       Tablet           3
test:
   ProductName  Quantities
2       Mobile          12
3           PC          10
11       RasPi           3
6       Tablet           5
---split_2---
train:
   ProductName  Quantities
7       Tablet           3
3           PC          10
2       Mobile          12
8       Tablet           4
9        RasPi           2
11       RasPi           3
1       Mobile          15
5           PC           9
test:
   ProductName  Quantities
4           PC           8
0       Mobile          20
6       Tablet           5
10       RasPi           1
---split_3---
train:
   ProductName  Quantities
2       Mobile          12
9        RasPi           2
11       RasPi           3
3           PC          10
0       Mobile          20
6       Tablet           5
8       Tablet           4
4           PC           8
test:
   ProductName  Quantities
5           PC           9
7       Tablet           3
1       Mobile          15
10       RasPi           1