有效地查找数组中所有值的索引

时间:2016-08-18 08:48:26

标签: python numpy

我有一个非常大的数组,由0到N之间的整数组成,其中每个值至少出现一次。

我想知道,对于每个值 k ,我的数组中所有索引的数组的值等于 k

例如:

arr = np.array([0,1,2,3,2,1,0])
desired_output = {
    0: np.array([0,6]),
    1: np.array([1,5]),
    2: np.array([2,4]),
    3: np.array([3]),
    }

现在,我正在通过range(N+1)循环完成此操作,并且呼叫np.where N次。

indices = {}
for value in range(max(arr)+1):
    indices[value] = np.where(arr == value)[0]

这个循环是我代码中最慢的部分。 (arr==value评估和np.where调用都会占用大量时间。)是否有更有效的方法来执行此操作?

我也试过玩np.unique(arr, return_index=True),但这只是告诉我第一个索引,而不是所有索引。

4 个答案:

答案 0 :(得分:7)

方法#1

这是一种将这些索引作为数组列表的矢量化方法 -

sidx = arr.argsort()
unq, cut_idx = np.unique(arr[sidx],return_index=True)
indices = np.split(sidx,cut_idx)[1:]

如果你想要将每个独特元素与其索引相对应的最终字典,最后我们可以使用循环理解 -

dict_out = {unq[i]:iterID for i,iterID in enumerate(indices)}

方法#2

如果您只对阵列列表感兴趣,那么这里有一个替代性能 -

sidx = arr.argsort()
indices = np.split(sidx,np.flatnonzero(np.diff(arr[sidx])>0)+1)

答案 1 :(得分:3)

pythonic方式正在使用collections.defaultdict()

>>> from collections import defaultdict
>>> 
>>> d = defaultdict(list)
>>> 
>>> for i, j in enumerate(arr):
...     d[j].append(i)
... 
>>> d
defaultdict(<type 'list'>, {0: [0, 6], 1: [1, 5], 2: [2, 4], 3: [3]})

这是使用词典理解和numpy.where()的Numpythonic方式:

>>> {i: np.where(arr == i)[0] for i in np.unique(arr)}
{0: array([0, 6]), 1: array([1, 5]), 2: array([2, 4]), 3: array([3])}

如果你不想涉及字典,这是一种纯粹的Numpythonic方法:

>>> uniq = np.unique(arr)
>>> args, indices = np.where((np.tile(arr, len(uniq)).reshape(len(uniq), len(arr)) == np.vstack(uniq)))
>>> np.split(indices, np.where(np.diff(args))[0] + 1)
[array([0, 6]), array([1, 5]), array([2, 4]), array([3])]

答案 2 :(得分:1)

我不知道numpy,但你肯定可以在一次迭代中使用defaultdict来做到这一点:

indices = defaultdict(list)
for i, val in enumerate(arr):
    indices[val].append(i)

答案 3 :(得分:0)

使用numpy_indexed包的完全矢量化解决方案:

import numpy_indexed as npi
k, idx = npi.groupy_by(arr, np.arange(len(arr)))

更上一层楼;为什么你需要这些指数?后续的分组操作通常可以使用group_by功能[例如,npi.group_by(arr).mean(someotherarray)]更有效地计算,而无需显式计算密钥的索引。