StratifiedShuffleSplit的奇怪结果

时间:2019-10-30 19:26:29

标签: python pandas

当我在分层过程之前删除一些行时,会收到奇怪的结果

机器学习。我需要调查关于数据组的ML结果

from sklearn.model_selection import StratifiedShuffleSplit

def stratifid(df, target, test_sz = 0.2):
 """Split ``df`` into one stratified train/test pair.

 Args:
     df: DataFrame to split.
     target: Column label(s) of ``df`` whose values define the strata.
     test_sz: Forwarded to ``StratifiedShuffleSplit`` as ``test_size``.

 Returns:
     ``(train, test)`` rows selected from ``df`` for the single split.
 """
 # A single split with a fixed random_state.
 split = StratifiedShuffleSplit(n_splits = 1, test_size  = test_sz, random_state = 42)
 for tr_idx, te_idx in split.split(df, df[target]):
   # NOTE(review): split() appears to yield row positions, while .loc selects
   # by index label; the two only coincide when df has a default 0..n-1
   # index. This looks like the source of the NaN row reported below -- confirm.
   train = df.loc[tr_idx]
   test  = df.loc[te_idx]
 return train, test

# Ten-row toy dataset; every column is requested as int64.
df = pd.DataFrame(data = {
    'gender' :      [1,  1,  0, 1,  1,  0,  0,  0,  1,  0, ],
    'age' :         [13, 45, 1, 45, 15, 16, 16, 16, 15, 15],
    'cholesterol' : [1,  2,  2, 1, 1, 1, 1, 1, 1, 1],
    'smoke' :       [0,  0,  1, 1, 7, 8, 3, 4, 4, 2]},
     dtype = np.int64)

# Keep only rows with age > 13. The remaining rows keep their original index
# labels, so labels 0 (age 13) and 2 (age 1) are absent from df1's index.
df1 = df.loc[df['age'] > 13]

# Stratify on gender with test_sz = 0.2 and show the training part.
X_train, X_test = stratifid(df1, ['gender'], 0.2)
print(X_train)

I expect the data to be stratified correctly, but my output is the following:
   gender   age  cholesterol  smoke
0     NaN   NaN          NaN    NaN
4     1.0  15.0          1.0    7.0
1     1.0  45.0          2.0    0.0
6     0.0  16.0          1.0    3.0
3     1.0  45.0          1.0    1.0
7     0.0  16.0          1.0    4.0

NaN values are not expected.
If I stratify the whole df (i.e. when df1 = df), everything is OK. What am I doing wrong?

1 个答案:

答案 0 :(得分:0)

from sklearn.model_selection import StratifiedShuffleSplit
import pandas as pd
import numpy as np

def stratifid(df, target, test_sz = 0.2):
    """Split ``df`` into one stratified train/test pair.

    Rows are selected by position, so the result is correct for any index
    (for example a frame whose rows were filtered without resetting the
    index), not only for a default 0..n-1 index.

    Args:
        df: DataFrame to split.
        target: Column label(s) of ``df`` whose values define the strata.
        test_sz: Forwarded to ``StratifiedShuffleSplit`` as ``test_size``.

    Returns:
        ``(train, test)`` DataFrames that keep the original index labels
        of ``df``.
    """
    split = StratifiedShuffleSplit(n_splits=1, test_size=test_sz, random_state=42)
    # n_splits=1, so the generator yields exactly one (train, test) pair.
    tr_idx, te_idx = next(split.split(df, df[target]))
    # split() returns positional row indices. Label-based .loc only matches
    # them when the index is exactly 0..n-1; on a filtered frame it produced
    # all-NaN rows (or a KeyError on newer pandas). .iloc selects by position.
    return df.iloc[tr_idx], df.iloc[te_idx]

# Ten-row toy dataset; every column is stored as int64.
df = pd.DataFrame(
    data={
        'gender':      [1,  1,  0, 1,  1,  0,  0,  0,  1,  0],
        'age':         [13, 45, 1, 45, 15, 16, 16, 16, 15, 15],
        'cholesterol': [1,  2,  2, 1,  1,  1,  1,  1,  1,  1],
        'smoke':       [0,  0,  1, 1,  7,  8,  3,  4,  4,  2],
    },
    dtype=np.int64,
)

# Keep rows with age > 13 and renumber them 0..n-1 so that index labels
# coincide with row positions. drop=True discards the old labels instead of
# adding an 'index' column, and chaining avoids mutating a slice of df in place.
df1 = df.loc[df['age'] > 13].reset_index(drop=True)

# Stratify on gender with test_sz = 0.2 and show the training part.
X_train, X_test = stratifid(df1, ['gender'], 0.2)
print(X_train)