给出两个numpy数组,即:
images.shape: (60000, 784) # An array containing 60000 images
labels.shape: (60000, 10) # An array of labels for each image
labels
的每一行在特定索引处包含1
,以指示images
中相关示例的类。 (所以[0 0 1 0 0 0 0 0 0 0]
表示该示例属于Class 2(假设我们的类索引从0开始)。
我正在尝试有效地分隔images
,以便我可以同时操作属于特定类的所有图像。最明显的解决方案是使用for
循环(如下所示)。但是,我不确定如何过滤images
,以便仅返回具有相应labels
的那些。
for i in range(0, labels.shape[1]):
class_images = # (?) Array containing all images that belong to class i
顺便说一下,我也想知道是否有更有效的方法可以消除for
循环的使用。
答案 0 :(得分:1)
一种方法是将您的标签数组转换为bool并将其用于索引:
classes = []
blabels = labels.astype(bool)
for i in range(10):
classes.append(images[blabels[:, i], :])
或者使用列表理解作为单行代码:
classes = [images[l.astype(bool), :] for l in labels.T]
答案 1 :(得分:0)
_classes= [[] for x in range(10)]
for image_index , element in enumerate(labels):
_classes[element.index(1)].append(image_index)
例如_classes [0]将包含被归类为class0的图像索引。