我有一个2D数组,其中每一行都有一个标签,该标签存储在单独的数组中(不一定是唯一的)。对于每个标签,我想从2D数组中提取具有该标签的行。我想要的一个基本的工作示例是这样:
import numpy as np
data=np.array([[1,2],[3,5],[7,10], [20,32],[0,0]])
label=np.array([1,1,1,0,1])
#very simple approach
label_values=np.unique(label)
res=[]
for la in label_values:
data_of_this_label_val=data[label==la]
res+=[data_of_this_label_val]
print(res)
结果(res)可以采用任何格式,只要它易于访问即可。在上面的示例中,应该是
[array([[20, 32]]), array([[ 1, 2],
[ 3, 5],
[ 7, 10],
[ 0, 0]])]
请注意,我可以轻松地将列表中的每个元素关联到label_values
中的唯一标签之一(即按索引)。
虽然可行,但使用for循环可能会花费大量时间,尤其是在我的标签向量很大的情况下。可以加快或编码得更优雅吗?
答案 0 :(得分:3)
您可以argsort
标签(这是unique
在幕后的做法)。
如果您的标签是非负整数,如示例所示,您可以更便宜一些,请参见https://stackoverflow.com/a/53002966/7207392。
>>> import numpy as np
>>>
>>> data=np.array([[1,2],[3,5],[7,10], [20,32],[0,0]])
>>> label=np.array([1,1,1,0,1])
>>>
>>> idx = label.argsort()
# use kind='mergesort' if you require a stable sort, i.e. one that
# preserves the order of equal labels
>>> ls = label[idx]
>>> split = 1 + np.where(ls[1:] != ls[:-1])[0]
>>> np.split(data[idx], split)
[array([[20, 32]]), array([[ 1, 2],
[ 3, 5],
[ 7, 10],
[ 0, 0]])]
答案 1 :(得分:2)
很遗憾,groupby
中没有内置的numpy
函数,尽管您可以编写替代方法。但是,如果可以使用pandas
,可以更简洁地解决您的问题:
import pandas as pd
res = pd.DataFrame(data).groupby(label).apply(lambda x: x.values).tolist()
# or, if performance is important, the following will be faster on large arrays,
# but less readable IMO:
res = [data[i] for i in pd.DataFrame(data).groupby(label).groups.values()]
[array([[20, 32]]), array([[ 1, 2],
[ 3, 5],
[ 7, 10],
[ 0, 0]])]