numpy数组中的多个元素的索引

时间:2018-07-22 14:54:44

标签: python numpy

我有一个numpy数组和以下列表

y=np.array([[1],[2],[1],[3],[1],[3],[2],[2]])
x=[1,2,3]

我想返回一个数组的元组,每个数组包含y中x的每个元素的索引。 即

(array([[0,2,4]]),array([[1,6,7]]),array([[3,5]]))

是否可以以矢量化方式(没有任何循环)完成此操作?

4 个答案:

答案 0 :(得分:1)

尝试以下操作:

y = y.flatten()
[np.where(y == searchval)[0] for searchval in x]

答案 1 :(得分:1)

一种解决方案是map

y = y.reshape(1,len(y))
map(lambda k: np.where(y==k)[-1], x)

[array([0, 2, 4]), 
 array([1, 6, 7]), 
 array([3, 5])]

合理的性能。对于100000行,

%timeit list(map(lambda k: np.where(y==k), x))
3.1 ms ± 113 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

答案 2 :(得分:0)

对于这个小例子,字典方法实际上更快(然后是wheres):

dd = {i:[] for i in [1,2,3]}
for i,v in enumerate(y):
   v=v[0]
   if v in dd:
       dd[v].append(i)
list(dd.values())

其他SO问题中也出现了此问题。已经提出了使用uniquesort的替代方案,但是它们更复杂,更难以重建-不一定更快。

对于numpy来说,这不是一个理想的问题。结果是数组列表或大小不同的列表,这是一个很好的线索,表明不可能使用简单的“矢量化”全数组解决方案。如果速度是一个足够重要的问题,则您可能需要查看numbacython的实现。

根据值的混合,不同的方法可能具有不同的相对时间。唯一值很少,但是长子列表可能更喜欢使用重复的where的方法。带有简短子列表的许多唯一值可能会支持在y上迭代的方法。

答案 3 :(得分:0)

您可以使用collections.defaultdict后加一个理解:

y = np.array([[1],[2],[1],[3],[1],[3],[2],[2]])
x = [1,2,3]

from collections import defaultdict

d = defaultdict(list)
for idx, item in enumerate(y.flat):
    d[item].append(idx)

res = tuple(np.array(d[k]) for k in x)

(array([0, 2, 4]), array([1, 6, 7]), array([3, 5]))