对于许多标签,获取与标签对应的行

时间:2018-11-06 05:05:49

标签: python arrays python-3.x sorting numpy

我有一个2D数组,其中每一行都有一个标签,该标签存储在单独的数组中(不一定是唯一的)。对于每个标签,我想从2D数组中提取具有该标签的行。我想要的一个基本的工作示例是这样:

import numpy as np

data=np.array([[1,2],[3,5],[7,10], [20,32],[0,0]])
label=np.array([1,1,1,0,1])

#very simple approach
label_values=np.unique(label)
res=[]
for la in label_values:
    data_of_this_label_val=data[label==la]
    res+=[data_of_this_label_val]
print(res)

结果(res)可以采用任何格式,只要它易于访问即可。在上面的示例中,应该是

[array([[20, 32]]), array([[ 1,  2],
   [ 3,  5],
   [ 7, 10],
   [ 0,  0]])]

请注意,我可以轻松地将列表中的每个元素关联到label_values中的唯一标签之一(即按索引)。

虽然可行,但使用for循环可能会花费大量时间,尤其是在我的标签向量很大的情况下。可以加快或编码得更优雅吗?

2 个答案:

答案 0 :(得分:3)

您可以argsort标签(这是unique在幕后的做法)。

如果您的标签是非负整数,如示例所示,您可以更便宜一些,请参见https://stackoverflow.com/a/53002966/7207392

>>> import numpy as np
>>> 
>>> data=np.array([[1,2],[3,5],[7,10], [20,32],[0,0]])
>>> label=np.array([1,1,1,0,1])
>>> 
>>> idx = label.argsort()
# use kind='mergesort' if you require a stable sort, i.e. one that
# preserves the order of equal labels
>>> ls = label[idx]
>>> split = 1 + np.where(ls[1:] != ls[:-1])[0]
>>> np.split(data[idx], split)
[array([[20, 32]]), array([[ 1,  2],
       [ 3,  5],
       [ 7, 10],
       [ 0,  0]])]

答案 1 :(得分:2)

很遗憾,groupby中没有内置的numpy函数,尽管您可以编写替代方法。但是,如果可以使用pandas,可以更简洁地解决您的问题:

import pandas as pd

res = pd.DataFrame(data).groupby(label).apply(lambda x: x.values).tolist()
# or, if performance is important, the following will be faster on large arrays, 
# but less readable IMO:
res = [data[i] for i in pd.DataFrame(data).groupby(label).groups.values()]

[array([[20, 32]]), array([[ 1,  2],
       [ 3,  5],
       [ 7, 10],
       [ 0,  0]])]