我有一个非常大的数组,由0到N之间的整数组成,其中每个值至少出现一次。
我想知道,对于每个值 k ,我的数组中所有索引的数组的值等于 k 。
例如:
arr = np.array([0,1,2,3,2,1,0])
desired_output = {
0: np.array([0,6]),
1: np.array([1,5]),
2: np.array([2,4]),
3: np.array([3]),
}
现在,我正在通过range(N+1)
循环完成此操作,并且呼叫np.where
N次。
indices = {}
for value in range(max(arr)+1):
indices[value] = np.where(arr == value)[0]
这个循环是我代码中最慢的部分。 (arr==value
评估和np.where
调用都会占用大量时间。)是否有更有效的方法来执行此操作?
我也试过玩np.unique(arr, return_index=True)
,但这只是告诉我第一个索引,而不是所有索引。
答案 0 :(得分:7)
方法#1
这是一种将这些索引作为数组列表的矢量化方法 -
sidx = arr.argsort()
unq, cut_idx = np.unique(arr[sidx],return_index=True)
indices = np.split(sidx,cut_idx)[1:]
如果你想要将每个独特元素与其索引相对应的最终字典,最后我们可以使用循环理解 -
dict_out = {unq[i]:iterID for i,iterID in enumerate(indices)}
方法#2
如果您只对阵列列表感兴趣,那么这里有一个替代性能 -
sidx = arr.argsort()
indices = np.split(sidx,np.flatnonzero(np.diff(arr[sidx])>0)+1)
答案 1 :(得分:3)
pythonic方式正在使用collections.defaultdict()
:
>>> from collections import defaultdict
>>>
>>> d = defaultdict(list)
>>>
>>> for i, j in enumerate(arr):
... d[j].append(i)
...
>>> d
defaultdict(<type 'list'>, {0: [0, 6], 1: [1, 5], 2: [2, 4], 3: [3]})
这是使用词典理解和numpy.where()
的Numpythonic方式:
>>> {i: np.where(arr == i)[0] for i in np.unique(arr)}
{0: array([0, 6]), 1: array([1, 5]), 2: array([2, 4]), 3: array([3])}
如果你不想涉及字典,这是一种纯粹的Numpythonic方法:
>>> uniq = np.unique(arr)
>>> args, indices = np.where((np.tile(arr, len(uniq)).reshape(len(uniq), len(arr)) == np.vstack(uniq)))
>>> np.split(indices, np.where(np.diff(args))[0] + 1)
[array([0, 6]), array([1, 5]), array([2, 4]), array([3])]
答案 2 :(得分:1)
我不知道numpy,但你肯定可以在一次迭代中使用defaultdict来做到这一点:
indices = defaultdict(list)
for i, val in enumerate(arr):
indices[val].append(i)
答案 3 :(得分:0)
使用numpy_indexed包的完全矢量化解决方案:
import numpy_indexed as npi
k, idx = npi.groupy_by(arr, np.arange(len(arr)))
更上一层楼;为什么你需要这些指数?后续的分组操作通常可以使用group_by功能[例如,npi.group_by(arr).mean(someotherarray)]更有效地计算,而无需显式计算密钥的索引。