我有一个numpy数组和以下列表
y=np.array([[1],[2],[1],[3],[1],[3],[2],[2]])
x=[1,2,3]
我想返回一个数组的元组,每个数组包含y中x的每个元素的索引。 即
(array([[0,2,4]]),array([[1,6,7]]),array([[3,5]]))
是否可以以矢量化方式(没有任何循环)完成此操作?
答案 0 :(得分:1)
尝试以下操作:
y = y.flatten()
[np.where(y == searchval)[0] for searchval in x]
答案 1 :(得分:1)
一种解决方案是map
y = y.reshape(1,len(y))
map(lambda k: np.where(y==k)[-1], x)
[array([0, 2, 4]),
array([1, 6, 7]),
array([3, 5])]
合理的性能。对于100000行,
%timeit list(map(lambda k: np.where(y==k), x))
3.1 ms ± 113 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
答案 2 :(得分:0)
对于这个小例子,字典方法实际上更快(然后是wheres):
dd = {i:[] for i in [1,2,3]}
for i,v in enumerate(y):
v=v[0]
if v in dd:
dd[v].append(i)
list(dd.values())
其他SO问题中也出现了此问题。已经提出了使用unique
和sort
的替代方案,但是它们更复杂,更难以重建-不一定更快。
对于numpy
来说,这不是一个理想的问题。结果是数组列表或大小不同的列表,这是一个很好的线索,表明不可能使用简单的“矢量化”全数组解决方案。如果速度是一个足够重要的问题,则您可能需要查看numba
或cython
的实现。
根据值的混合,不同的方法可能具有不同的相对时间。唯一值很少,但是长子列表可能更喜欢使用重复的where
的方法。带有简短子列表的许多唯一值可能会支持在y
上迭代的方法。
答案 3 :(得分:0)
您可以使用collections.defaultdict
后加一个理解:
y = np.array([[1],[2],[1],[3],[1],[3],[2],[2]])
x = [1,2,3]
from collections import defaultdict
d = defaultdict(list)
for idx, item in enumerate(y.flat):
d[item].append(idx)
res = tuple(np.array(d[k]) for k in x)
(array([0, 2, 4]), array([1, 6, 7]), array([3, 5]))