NumPy: get the indices of rows whose sum is greater than 10

Time: 2019-09-27 10:03:50

Tags: python numpy

I have the following array:

a = np.array([[1,2,9], [5,2,4], [1,2,3]])

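The task is to find the indices of all rows whose sum is greater than 10. In my example, the result should look like [0,1].

I need a filter similar to the one recommended in this post: Filter rows of a numpy array?

However, I only need the indices, not the actual values or the rows themselves.

My current code is: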

temp = a[np.sum(a, axis=1) > 5]

How do I get the original indices of the filtered rows?

4 answers:

Answer 0 (score: 4)

You can use np.argwhere() like this:

>>> import numpy as np

>>> a = np.array([[1,2,9], [5,2,4], [1,2,3]])
>>> np.argwhere(np.sum(a, axis=1) > 10)
array([[0],
       [1]])
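
Note that np.argwhere returns a 2-D array here (one row per match), while the question asks for a flat result like [0,1]. As a minimal sketch that is not part of the original answer, the result could be flattened with ravel(), or np.where could be used on the same mask instead; the variable names below are illustrative only:

```python
import numpy as np

a = np.array([[1, 2, 9], [5, 2, 4], [1, 2, 3]])
mask = np.sum(a, axis=1) > 10    # boolean mask, one entry per row

idx_2d = np.argwhere(mask)       # 2-D result: one row per matching index
idx_flat = idx_2d.ravel()        # flattened to a 1-D array of row indices
idx_where = np.where(mask)[0]    # same indices obtained via np.where
```

Both idx_flat and idx_where hold the row numbers as a 1-D array, which can be turned into a plain list with .tolist() if needed.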

Answer 1 (score: 1)

You can check where the row sum is greater than 10 and use np.flatnonzero to get the indices:

a = np.array([[1,2,9], [5,2,4], [1,2,3]])

np.flatnonzero(a.sum(1) > 10)
# array([0, 1], dtype=int64)

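Answer 2 (score: 0)

I tried several variants. On this small array, the fastest seems to be the second version (the list comprehension):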

import numpy as np
a = np.array([[1,2,9], [5,2,4], [1,2,1]])
print(a)

%timeit temp = a[np.sum(a, axis=1) > 5]
temp = a[np.sum(a, axis=1) > 5]
print(temp)

%timeit temp = [n for n, curr in enumerate(a) if sum(curr) > 5 ]
temp = [n for n, curr in enumerate(a) if sum(curr) > 5 ]
print(temp)

%timeit temp = np.argwhere(np.sum(a, axis=1) > 5)
temp = np.argwhere(np.sum(a, axis=1) > 5)
print(temp)

%timeit temp = np.flatnonzero(a.sum(1) > 10)
temp = np.flatnonzero(a.sum(1) > 10)
print(temp)

The results are:

[[1 2 9]
 [5 2 4]
 [1 2 1]]
The slowest run took 12.37 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 7.47 µs per loop
[[1 2 9]
 [5 2 4]]
100000 loops, best of 3: 5.09 µs per loop
[0, 1]
The slowest run took 9.83 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 13.3 µs per loop
[[0]
 [1]]
The slowest run took 6.78 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 6.8 µs per loop
[0 1]
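
These timings come from a 3x3 array, where per-call overhead probably dominates, so the ranking may not carry over to larger inputs. The following is a rough sketch for repeating the comparison on a bigger array with the standard timeit module; the array size, the seed, and the helper names are assumptions made for illustration, and no timings are claimed here:

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)                 # seed chosen arbitrarily
big = rng.integers(0, 10, size=(10_000, 3))    # hypothetical larger input

def with_listcomp():
    # Pure-Python loop over rows, as in the second version above.
    return [n for n, row in enumerate(big) if sum(row) > 10]

def with_flatnonzero():
    # Vectorized row sums followed by np.flatnonzero.
    return np.flatnonzero(big.sum(axis=1) > 10)

# Both approaches should agree before their timings are compared.
assert with_listcomp() == with_flatnonzero().tolist()

for fn in (with_listcomp, with_flatnonzero):
    seconds = timeit.timeit(fn, number=5)
    print(f"{fn.__name__}: {seconds / 5 * 1e3:.3f} ms per call")
```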

Answer 3 (score: 0)

You can simply use:

temp = np.sum(a, axis=1) > 10
np.arange(len(a))[temp]
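
This works because temp is a boolean mask with one entry per row, and indexing np.arange(len(a)) with it keeps only the row numbers where the mask is True. A self-contained sketch using the array from the question (the names mask, indices, and same_rows are illustrative):

```python
import numpy as np

a = np.array([[1, 2, 9], [5, 2, 4], [1, 2, 3]])
mask = np.sum(a, axis=1) > 10         # row sums are 12, 11, 6
indices = np.arange(len(a))[mask]     # keep row numbers where mask is True

# Indexing with the recovered indices selects the same rows as the mask.
same_rows = np.array_equal(a[indices], a[mask])
```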