如何平衡numpy数组中的类?

时间:2017-05-28 23:16:20

标签: python arrays numpy multidimensional-array

我有2个numpy数组如下:

images 包含图片文件的名称(images.shape为(N,3,128,128)): image_1.jpg image_2.jpg image_3.jpg image_4.jpg

labels 包含相应的标签(0-3)(labels.shape是(N,)): 1 1 3 2

我面临的问题是这些课程是不平衡的,第3课>> 1> 2> 0.

我想通过以下方式平衡最终数据集:

  • 计算每个班级的图像(样本)数量
  • 获取具有最少图像数量的类的计数
  • 将该计数用作其他3个类的最大图像/标签数
  • 随机弹出imageslabels
  • 中其他3个班级的多余图片/标签

到目前为止,我正在使用Counter来识别每个类的图像数量:

from Collections import Counter
import numpy as np

count = Counter(labels)
print(count)

>>>Counter({'1': 2991, '0': 2953, '2': 2510, '3': 2488})

您如何建议我随机弹出imageslabels中的匹配元素,以便它们包含2488个0,1和2类的样本?

1 个答案:

答案 0 :(得分:1)

您可以使用np.random.choice创建一个整数值掩码,您可以将其应用于标签和图像以平衡数据集:

n = 2488

mask = np.hstack([np.random.choice(np.where(labels == l)[0], n, replace=False)
                      for l in np.unique(labels)])