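I have the following array:
a = np.array([[1,2,9], [5,2,4], [1,2,3]])
The task is to find the indices of all rows whose sum is greater than 10. For my example, the result should be something like [0,1].
I need a filter similar to the one recommended in this post: Filter rows of a numpy array?
However, I only need the indices, not the actual values or a separate array of them.
My current code is as follows: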
temp = a[np.sum(a, axis=1) > 10]
How can I get the original indices of the filtered rows?
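To see where the indices come from, it helps to look at the intermediate values. Below is a minimal sketch using the array from the question (with the outer brackets in place) and the threshold of 10 from the task statement. The names `row_sums`, `mask` and `indices` are illustrative only:

```python
import numpy as np

a = np.array([[1, 2, 9], [5, 2, 4], [1, 2, 3]])

# Sum each row (axis=1 collapses the columns): 1+2+9, 5+2+4, 1+2+3
row_sums = a.sum(axis=1)
print(row_sums)  # [12 11  6]

# Boolean mask with one entry per row: True where the row sum exceeds 10
mask = row_sums > 10
print(mask)  # [ True  True False]

# The wanted indices are the positions of the True entries
indices = np.nonzero(mask)[0]
print(indices.tolist())  # [0, 1]
```

Filtering with `a[mask]` throws these positions away. Each approach below recovers them from the same boolean mask.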
Answer 0 (score: 4)
You can use np.argwhere() like this:
>>> import numpy as np
>>> a = np.array([[1,2,9], [5,2,4], [1,2,3]])
>>> np.argwhere(np.sum(a, axis=1) > 10)
array([[0],
       [1]])
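np.argwhere returns a 2-D array with one row of coordinates per match. Here it has shape (2, 1), not the flat list [0, 1] from the question. If a 1-D result is wanted, one more step flattens it. A small sketch, where `hits` and `indices` are illustrative names:

```python
import numpy as np

a = np.array([[1, 2, 9], [5, 2, 4], [1, 2, 3]])

# argwhere on a 1-D mask gives one single-element row per match: shape (2, 1)
hits = np.argwhere(np.sum(a, axis=1) > 10)
print(hits.shape)  # (2, 1)

# Drop the extra dimension to get a plain 1-D index array
indices = hits.ravel()
print(indices.tolist())  # [0, 1]
```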
Answer 1 (score: 1)
You can check where the row sums are greater than 10 and use np.flatnonzero to get the indices:
import numpy as np
a = np.array([[1,2,9], [5,2,4], [1,2,3]])
np.flatnonzero(a.sum(1) > 10)
# array([0, 1], dtype=int64)
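NumPy documents np.flatnonzero(x) as equivalent to np.nonzero(np.ravel(x))[0]. For a 1-D mask, the same indices therefore also come from np.nonzero or from the one-argument form of np.where. A quick check that the three spellings agree:

```python
import numpy as np

a = np.array([[1, 2, 9], [5, 2, 4], [1, 2, 3]])
mask = a.sum(1) > 10

via_flatnonzero = np.flatnonzero(mask)
via_nonzero = np.nonzero(mask)[0]  # nonzero returns a tuple, one array per dimension
via_where = np.where(mask)[0]      # one-argument where behaves like nonzero

for idx in (via_flatnonzero, via_nonzero, via_where):
    print(idx.tolist())  # [0, 1] each time
```

flatnonzero skips the tuple indexing, which is why it reads a little more directly for this task.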
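Answer 2 (score: 0)
I tried several versions of the code. The fastest seems to be the second one (the list comprehension), at least for this small array: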
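# %timeit is an IPython magic: run this in IPython or Jupyter, not plain Python
import numpy as np
a = np.array([[1,2,9], [5,2,4], [1,2,1]])
print(a)
# Version 1: boolean-mask filtering (returns the rows themselves, not their indices)
%timeit temp = a[np.sum(a, axis=1) > 5]
temp = a[np.sum(a, axis=1) > 5]
print(temp)
# Version 2: list comprehension over the rows
%timeit temp = [n for n, curr in enumerate(a) if sum(curr) > 5]
temp = [n for n, curr in enumerate(a) if sum(curr) > 5]
print(temp)
# Version 3: np.argwhere (returns a column of indices)
%timeit temp = np.argwhere(np.sum(a, axis=1) > 5)
temp = np.argwhere(np.sum(a, axis=1) > 5)
print(temp)
# Version 4: np.flatnonzero (threshold 10 here; with this array it selects the same rows as > 5)
%timeit temp = np.flatnonzero(a.sum(1) > 10)
temp = np.flatnonzero(a.sum(1) > 10)
print(temp)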
Output:
[[1 2 9]
[5 2 4]
[1 2 1]]
The slowest run took 12.37 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 7.47 µs per loop
[[1 2 9]
[5 2 4]]
100000 loops, best of 3: 5.09 µs per loop
[0, 1]
The slowest run took 9.83 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 13.3 µs per loop
[[0]
[1]]
The slowest run took 6.78 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 6.8 µs per loop
[0 1]
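These timings come from a 3 × 3 array, where fixed per-call overhead dominates. They say little about how the approaches scale. The sketch below repeats the comparison on a larger random array using the standard-library `timeit` module. The array size, value range, threshold, repeat count and the `with_*` function names are all arbitrary assumptions.

No timings are shown because they depend on the machine. Still, the vectorized versions would typically be expected to pull ahead of the Python-level loop as the row count grows.

```python
import timeit

import numpy as np

# Hypothetical test data: 10,000 rows of 3 small non-negative integers
rng = np.random.default_rng(0)
a = rng.integers(0, 10, size=(10_000, 3))
threshold = 10

def with_list_comprehension():
    # Python-level loop over the rows, as in the second version above
    return [n for n, row in enumerate(a) if sum(row) > threshold]

def with_argwhere():
    return np.argwhere(np.sum(a, axis=1) > threshold).ravel()

def with_flatnonzero():
    return np.flatnonzero(a.sum(1) > threshold)

# All three must find the same rows before their speed is worth comparing
assert with_list_comprehension() == with_argwhere().tolist() == with_flatnonzero().tolist()

for fn in (with_list_comprehension, with_argwhere, with_flatnonzero):
    seconds = timeit.timeit(fn, number=20)
    print(f"{fn.__name__}: {seconds / 20 * 1e6:.1f} µs per call")
```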
Answer 3 (score: 0)
You can simply use:
temp = np.sum(a, axis=1) > 10
np.arange(len(a))[temp]
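Boolean indexing of np.arange(len(a)) keeps exactly the positions where the mask is True, so it yields the same indices as np.flatnonzero. If the rows themselves are needed later, those indices can fetch them again. A sketch with the question's data, where `mask` and `idx` are illustrative names:

```python
import numpy as np

a = np.array([[1, 2, 9], [5, 2, 4], [1, 2, 3]])
mask = np.sum(a, axis=1) > 10

# Boolean indexing on a range of row numbers keeps only the True positions
idx = np.arange(len(a))[mask]
print(idx.tolist())  # [0, 1]

# Integer indexing with those positions returns the same rows as a[mask]
print(np.array_equal(a[idx], a[mask]))  # True
```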