从每个类标签中抽取X个示例

时间:2018-01-24 14:45:44

标签: python random scikit-learn resampling sklearn-pandas

我有一个带50 classes的数据集(numpy向量)和9000个训练样例。

x_train=(9000,2048)
y_train=(9000,)  # Classes are strings 
classes=list(set(y_train))

我想构建一个子数据集,使每个类都有5个例子

这意味着我得到5*50=250个训练样例。因此,我的子数据集将采用以下形式:

sub_train_data=(250,2048)
sub_train_labels=(250,)

备注:我们从每个班级随机抽取5个例子(班级总数= 50)

谢谢

3 个答案:

答案 0 :(得分:1)

以下是该问题的解决方案:

from collections import Counter
import numpy as np
import matplotlib.pyplot as plt

def balanced_sample_maker(X, y, sample_size, random_seed=42):
    uniq_levels = np.unique(y)
    uniq_counts = {level: sum(y == level) for level in uniq_levels}

    if not random_seed is None:
        np.random.seed(random_seed)

    # find observation index of each class levels
    groupby_levels = {}
    for ii, level in enumerate(uniq_levels):
        obs_idx = [idx for idx, val in enumerate(y) if val == level]
        groupby_levels[level] = obs_idx
    # oversampling on observations of each label
    balanced_copy_idx = []
    for gb_level, gb_idx in groupby_levels.items():
        over_sample_idx = np.random.choice(gb_idx, size=sample_size, replace=True).tolist()
        balanced_copy_idx+=over_sample_idx
    np.random.shuffle(balanced_copy_idx)

    data_train=X[balanced_copy_idx]
    labels_train=y[balanced_copy_idx]
    if  ((len(data_train)) == (sample_size*len(uniq_levels))):
        print('number of sampled example ', sample_size*len(uniq_levels), 'number of sample per class ', sample_size, ' #classes: ', len(list(set(uniq_levels))))
    else:
        print('number of samples is wrong ')

    labels, values = zip(*Counter(labels_train).items())
    print('number of classes ', len(list(set(labels_train))))
    check = all(x == values[0] for x in values)
    print(check)
    if check == True:
        print('Good all classes have the same number of examples')
    else:
        print('Repeat again your sampling your classes are not balanced')
    indexes = np.arange(len(labels))
    width = 0.5
    plt.bar(indexes, values, width)
    plt.xticks(indexes + width * 0.5, labels)
    plt.show()
    return data_train,labels_train

X_train,y_train=balanced_sample_maker(X,y,10)

受到Scikit-learn balanced subsampling

的启发

答案 1 :(得分:0)

我通常使用scikit-learn的技巧。我使用StratifiedShuffleSplit函数。因此,如果我必须选择火车集合的1 / n分数,则将数据分为n倍,并将测试数据的比例(test_size)设置为1-1 / n。这是一个仅使用1/10数据的示例。

sp = StratifiedShuffleSplit(n_splits=1, test_size=0.9, random_state=seed)
  for train_index, _ in sp.split(x_train, y_train):
    x_train, y_train = x_train[train_index], y_train[train_index]

答案 2 :(得分:0)

您可以使用数据框作为输入(就像我的情况一样),并使用下面的简单代码:

col = target
nsamples = min(t4m[col].value_counts().values)
res = pd.DataFrame()
for val in t4m[col].unique():
  t = t4m.loc[t4m[col] == val].sample(nsamples)
  res = pd.concat([res, t], ignore_index=True).sample(frac=1)

col 是带有类的列的名称。代码找到少数类,打乱数据帧,然后从每个类中抽取少数类的大小样本。

然后你可以将结果转换回 np.array