有没有一种快速的方法来查找二维数组在3d数组中的所有索引?
我有这个3d numpy数组:
arr = np.array([
[[0,1],[0,2],[0,3],[0,4],[0,4],[0,5],[0,5],[0,5],[0,5],[0,5]],
[[0,1],[0,2],[0,2],[0,2],[0,3],[0,4],[0,4],[0,4],[0,5],[0,5]],
[[0,1],[0,2],[0,3],[0,3],[0,3],[0,4],[0,4],[0,5],[0,5],[0,5]]
])
我想找到[0,4]
出现的所有索引。
我试过这个:
whereInd = np.argwhere(arr == np.array([0,4]))
但它不起作用。 预期结果是:
[[0 3],[0 4],[1 5],[1 6],[1 7],[2 5],[2 6]]
另一个问题是,这会很快吗?因为我想将它用于(10000,100,2)
数组。
答案 0 :(得分:2)
使用argwhere()
是一个好主意,但您还需要使用all()
来获得所需的输出:
>>> np.argwhere((arr == [0, 4]).all(axis=2))
array([[0, 3],
[0, 4],
[1, 5],
[1, 6],
[1, 7],
[2, 5],
[2, 6]])
此处all()
用于在比较后检查每一行[True, True]
(即行等于[0, 4]
)。在3D数组中,axis=2
指向行。
这会将维数减少到两个,argwhere()
会返回所需的索引数组。
关于性能,此方法应该很快处理您指定大小的数组:
In [20]: arr = np.random.randint(0, 10, size=(10000, 100, 2))
In [21]: %timeit np.argwhere((arr == [0, 4]).all(axis=2))
10 loops, best of 3: 44.9 ms per loop
答案 1 :(得分:0)
我能想到的最简单的解决方案是:
import numpy as np
arr = np.array([
[[0,1],[0,2],[0,3],[0,4],[0,4],[0,5],[0,5],[0,5],[0,5],[0,5]],
[[0,1],[0,2],[0,2],[0,2],[0,3],[0,4],[0,4],[0,4],[0,5],[0,5]],
[[0,1],[0,2],[0,3],[0,3],[0,3],[0,4],[0,4],[0,5],[0,5],[0,5]]
])
whereInd = []
for i,row in enumerate(arr):
for j,elem in enumerate(row):
if all(elem == [0,4]):
whereInd.append((i,j))
print whereInd
#prints [(0, 3), (0, 4), (1, 5), (1, 6), (1, 7), (2, 5), (2, 6)]
虽然np.argwhere
的任何解决方案都应该快10倍左右。