如何从MNIST数据集中选择特定数量的每个类别

时间:2019-12-08 08:35:39

标签: python keras mnist

我正在使用tensorflow来处理Mnist。我需要使用每个类的特定数量的数据来训练我的网络(例如,每个位数500个样本)。 我找到了how to sort the DB with class labels

idx = np.argsort(y_train)
x_train_sorted = x_train[idx]
y_train_sorted = y_train[idx]

但是我该如何选择500个数字,然后将它们与随机播放结合起来呢?

1 个答案:

答案 0 :(得分:0)

如果您将DataFrame合二为一,则可以groupby贴上标签,然后获得headtail

import pandas as pd

df = pd.DataFrame({
    'X1': [1,2,3,4,5,6,7,8,9,10,11,12],
    'X2': [21,22,23,24,25,26,27,28,29,30,31,32],
    'label': ['a','a','a','a','b','b','b','b','c','c','c','c']
})

groups = df.groupby('label')

df2 = groups.head(2)    
#df2 = groups.apply(lambda x:x[:2]) # the same as head(2)
#df2 = groups.apply(lambda x:x.sample(frac=1)[:2]) # shuffled before get values

print(df2)

结果

   X1  X2 label
0   1  21     a
1   2  22     a
4   5  25     b
5   6  26     b
8   9  29     c
9  10  30     c

然后您可以将其洗牌并分成X_trainy_train

df2 = df2.sample(frac=1).reset_index(drop=True)

X_train = df2[['X1','X2']]
y_train = df2['label']

print(X_train)
print(y_train)