使用另一个标签数组过滤numpy数组

时间:2017-03-10 00:27:22

标签: python arrays numpy filter

给出两个numpy数组,即:

images.shape: (60000, 784) # An array containing 60000 images
labels.shape: (60000, 10)  # An array of labels for each image

labels的每一行在特定索引处包含1,以指示images中相关示例的类。 (所以[0 0 1 0 0 0 0 0 0 0]表示该示例属于Class 2(假设我们的类索引从0开始)。

我正在尝试有效地分隔images,以便我可以同时操作属于特定类的所有图像。最明显的解决方案是使用for循环(如下所示)。但是,我不确定如何过滤images,以便仅返回具有相应labels的那些。

for i in range(0, labels.shape[1]):
  class_images = # (?) Array containing all images that belong to class i

顺便说一下,我也想知道是否有更有效的方法可以消除for循环的使用。

2 个答案:

答案 0 :(得分:1)

一种方法是将您的标签数组转换为bool并将其用于索引:

classes = []
blabels = labels.astype(bool)
for i in range(10):
    classes.append(images[blabels[:, i], :])

或者使用列表理解作为单行代码:

classes = [images[l.astype(bool), :] for l in labels.T]

答案 1 :(得分:0)

_classes= [[] for x in range(10)]
for image_index , element in enumerate(labels):
    _classes[element.index(1)].append(image_index)

例如_classes [0]将包含被归类为class0的图像索引。