How can I make sure StratifiedShuffleSplit preserves the imbalanced class ratio?

Time: 2018-12-13 16:14:45

Tags: python cross-validation

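My dataset is imbalanced. When fine-tuning a model, I need to make sure that StratifiedShuffleSplit actually draws samples from every class and keeps the inherent class ratio in each split. How can I test this?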

1 answer:

Answer 0: (score: 0)

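The function below runs the check on an imbalanced dataset with a 4:16 class ratio (4 samples of class 1, 16 samples of class 0).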

import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit


def test_cv():
    # 20 samples: 4 in the minority class (1), 16 in the majority class (0)
    X = np.array([1, 5, 4, 3, 4, 5, 6, 5, 4, 5, 3, 4, 3, 2, 3, 4, 1, 9, 3, 5])
    y = np.array([1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

    print('Class Ratio 1 / Total = {0} / {1}'.format(len(y[y == 1]), len(y)))
    print('Class Ratio 0 / Total = {0} / {1}'.format(len(y[y == 0]), len(y)))
    # Five random 70/30 splits, each stratified on y
    sss = StratifiedShuffleSplit(n_splits=5, test_size=0.3, random_state=0)
    for train_index, test_index in sss.split(X, y):
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]
        print('X_TRAIN: \t{0} X_TEST:{1}'.format(X_train, X_test))
        print('y_TRAIN: \t{0} y_TEST:{1}\n'.format(y_train, y_test))
        # Count how many minority-class samples landed in each part
        print('#1/#Train = {0} / {1}'.format(len(y_train[y_train == 1]), len(y_train)))
        print('#1/#Test = {0} / {1}'.format(len(y_test[y_test == 1]), len(y_test)))


test_cv()

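The output shows that the minority class (1) appears in every split, in both the training and the test part, at close to its original ratio of 4 / 20: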

Class Ratio 1 / Total = 4 / 20
Class Ratio 0 / Total = 16 / 20
X_TRAIN:    [3 6 4 5 5 2 9 1 3 4 4 3 3 3] X_TEST:[1 5 5 5 4 4]
y_TRAIN:    [0 0 1 0 0 0 0 1 1 0 0 0 0 0] y_TEST:[0 0 1 0 0 0]

#1/#Train = 3 / 14
#1/#Test = 1 / 6
X_TRAIN:    [4 4 3 5 5 4 2 3 3 6 5 4 1 3] X_TEST:[1 5 9 5 3 4]
y_TRAIN:    [1 0 0 0 0 0 0 1 0 0 1 0 0 0] y_TEST:[1 0 0 0 0 0]

#1/#Train = 3 / 14
#1/#Test = 1 / 6
X_TRAIN:    [5 4 4 3 3 5 1 3 2 4 5 3 9 5] X_TEST:[4 5 4 1 6 3]
y_TRAIN:    [0 0 0 0 0 0 0 0 0 1 0 1 0 1] y_TEST:[0 0 0 1 0 0]

#1/#Train = 3 / 14
#1/#Test = 1 / 6
X_TRAIN:    [3 4 1 3 5 3 4 5 2 5 5 3 1 3] X_TEST:[4 9 4 4 5 6]
y_TRAIN:    [0 0 0 0 0 1 0 0 0 1 0 0 1 0] y_TEST:[0 0 0 1 0 0]

#1/#Train = 3 / 14
#1/#Test = 1 / 6
X_TRAIN:    [5 9 1 6 4 3 4 4 5 5 2 3 5 3] X_TEST:[3 1 3 4 4 5]
y_TRAIN:    [0 0 0 0 0 0 0 1 0 1 0 0 0 1] y_TEST:[0 1 0 0 0 0]

#1/#Train = 3 / 14
#1/#Test = 1 / 6
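
To turn the visual check into an automated one, the printed counts can be replaced by assertions. The following is only a minimal sketch: the helper name minority_fractions, its positive_label parameter, and the 0.05 tolerance are assumptions made for illustration and are not part of scikit-learn. With splits this small the fractions cannot equal 0.2 exactly (3 / 14 and 1 / 6 are the closest achievable values), so the comparison has to allow some slack.

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit


def minority_fractions(y, positive_label=1, n_splits=5, test_size=0.3,
                       random_state=0):
    """Return (train_fraction, test_fraction) of positive_label per split.

    Hypothetical helper for illustration; not a scikit-learn function.
    """
    y = np.asarray(y)
    X = np.zeros((len(y), 1))  # placeholder features; stratification uses only y
    sss = StratifiedShuffleSplit(n_splits=n_splits, test_size=test_size,
                                 random_state=random_state)
    fractions = []
    for train_index, test_index in sss.split(X, y):
        train_frac = float(np.mean(y[train_index] == positive_label))
        test_frac = float(np.mean(y[test_index] == positive_label))
        fractions.append((train_frac, test_frac))
    return fractions


# Same label layout as in the answer: 4 samples of class 1, 16 of class 0
y = np.array([1] * 4 + [0] * 16)
overall = float(np.mean(y == 1))  # 4 / 20

for train_frac, test_frac in minority_fractions(y):
    # The tolerance is an assumption; pick one that suits your split sizes.
    assert abs(train_frac - overall) < 0.05
    assert abs(test_frac - overall) < 0.05
```

On a larger dataset the achievable fractions lie closer to the overall ratio, so a tighter tolerance may be appropriate.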