考虑数组a
a = np.array([3, 3, np.nan, 3, 3, np.nan])
我能做到
np.isnan(a).argmax()
但这需要找到所有np.nan
才能找到第一个
有更有效的方法吗?
我一直试图弄清楚我是否可以将参数传递给np.argpartition
,以便np.nan
首先排序而不是最后排序。
关于[dup]的编辑。
这个问题有几个不同的原因。
isnan
。关于第二次[dup]的编辑。
仍然处理平等和问题/答案已经过时,很可能已经过时了。
答案 0 :(得分:11)
考虑numba.jit
也可能值得;没有它,矢量化版本可能会在大多数情况下击败直接的纯Python搜索,但在编译代码后,普通搜索将起带头作用,至少在我的测试中:
In [63]: a = np.array([np.nan if i % 10000 == 9999 else 3 for i in range(100000)])
In [70]: %paste
import numba
def naive(a):
for i in range(len(a)):
if np.isnan(a[i]):
return i
def short(a):
return np.isnan(a).argmax()
@numba.jit
def naive_jit(a):
for i in range(len(a)):
if np.isnan(a[i]):
return i
@numba.jit
def short_jit(a):
return np.isnan(a).argmax()
## -- End pasted text --
In [71]: %timeit naive(a)
100 loops, best of 3: 7.22 ms per loop
In [72]: %timeit short(a)
The slowest run took 4.59 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 37.7 µs per loop
In [73]: %timeit naive_jit(a)
The slowest run took 6821.16 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 6.79 µs per loop
In [74]: %timeit short_jit(a)
The slowest run took 395.51 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 144 µs per loop
编辑:正如@hpaulj在回答中指出的那样,numpy
实际上附带了优化的短路搜索,其性能与上面的JITted搜索相当:
In [26]: %paste
def plain(a):
return a.argmax()
@numba.jit
def plain_jit(a):
return a.argmax()
## -- End pasted text --
In [35]: %timeit naive(a)
100 loops, best of 3: 7.13 ms per loop
In [36]: %timeit plain(a)
The slowest run took 4.37 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 7.04 µs per loop
In [37]: %timeit naive_jit(a)
100000 loops, best of 3: 6.91 µs per loop
In [38]: %timeit plain_jit(a)
10000 loops, best of 3: 125 µs per loop
答案 1 :(得分:8)
我将提名
a.argmax()
使用@fuglede's
测试数组:
In [1]: a = np.array([np.nan if i % 10000 == 9999 else 3 for i in range(100000)])
In [2]: np.isnan(a).argmax()
Out[2]: 9999
In [3]: np.argmax(a)
Out[3]: 9999
In [4]: a.argmax()
Out[4]: 9999
In [5]: timeit a.argmax()
The slowest run took 29.94 ....
10000 loops, best of 3: 20.3 µs per loop
In [6]: timeit np.isnan(a).argmax()
The slowest run took 7.82 ...
1000 loops, best of 3: 462 µs per loop
我没有安装numba
,所以可以比较一下。但是我相对于short
的加速比大于@fuglede's
6x。
我在Py3中测试,它接受<np.nan
,而Py2会引发运行时警告。但代码搜索表明这并不依赖于这种比较。
/numpy/core/src/multiarray/calculation.c
PyArray_ArgMax
使用轴(将感兴趣的内容移动到最后),并将操作委托给arg_func = PyArray_DESCR(ap)->f->argmax
,这是一个取决于dtype的函数。
在numpy/core/src/multiarray/arraytypes.c.src
中,它看起来像BOOL_argmax
短路,一遇到True
就会立即返回。
for (; i < n; i++) {
if (ip[i]) {
*max_ind = i;
return 0;
}
}
并且@fname@_argmax
也会在最大nan
上发生短路。 np.nan
是最大的&#39;也在argmin
。
#if @isfloat@
if (@isnan@(mp)) {
/* nan encountered; it's maximal */
return 0;
}
#endif
欢迎来自经验丰富的c
编码员的评论,但在我看来,至少对于np.nan
来说,普通argmax
将会得到我们能够获得的速度。
在生成9999
时使用a
表示a.argmax
时间取决于该值,与短路一致。
答案 2 :(得分:6)
以下是使用itertools.takewhile()
的pythonic方法:
from itertools import takewhile
sum(1 for _ in takewhile(np.isfinite, a))
使用generator_expression_within _ next
方法进行基准测试: 1
In [118]: a = np.repeat(a, 10000)
In [120]: %timeit next(i for i, j in enumerate(a) if np.isnan(j))
100 loops, best of 3: 12.4 ms per loop
In [121]: %timeit sum(1 for _ in takewhile(np.isfinite, a))
100 loops, best of 3: 11.5 ms per loop
但仍然(到目前为止)慢于numpy方法:
In [119]: %timeit np.isnan(a).argmax()
100000 loops, best of 3: 16.8 µs per loop
<子>
1.这种方法的问题是使用enumerate
函数。它首先从numpy数组返回一个enumerate
对象(这是一个像对象一样的迭代器),并且调用迭代器的生成器函数和next
属性需要时间。
子>
答案 3 :(得分:3)
在各种情况下寻找第一场比赛时,我们可以迭代并查找第一场比赛并在第一场比赛中退出,而不是去/处理整个阵列。所以,我们会采用Python's next function
的方法,就像这样 -
next((i for i, val in enumerate(a) if np.isnan(val)))
样品运行 -
In [192]: a = np.array([3, 3, np.nan, 3, 3, np.nan])
In [193]: next((i for i, val in enumerate(a) if np.isnan(val)))
Out[193]: 2
In [194]: a[2] = 10
In [195]: next((i for i, val in enumerate(a) if np.isnan(val)))
Out[195]: 5