找到第一个np.nan值的最有效方法是什么?

时间:2016-12-25 11:00:04

标签: python numpy

考虑数组a

a = np.array([3, 3, np.nan, 3, 3, np.nan])

我能做到

np.isnan(a).argmax()

但这需要找到所有np.nan才能找到第一个 有更有效的方法吗?

我一直试图弄清楚我是否可以将参数传递给np.argpartition,以便np.nan首先排序而不是最后排序。

关于[dup]的编辑。
这个问题有几个不同的原因。

  1. 这个问题和答案解决了价值观的平等问题。这与isnan
  2. 有关
  3. 这些答案都遭遇了我的答案所面临的同样问题。请注意,我提供了一个非常有效的答案,但强调了它的低效率。我正在寻找解决效率低下的问题。
  4. 关于第二次[dup]的编辑。

    仍然处理平等和问题/答案已经过时,很可能已经过时了。

4 个答案:

答案 0 :(得分:11)

考虑numba.jit也可能值得;没有它,矢量化版本可能会在大多数情况下击败直接的纯Python搜索,但在编译代码后,普通搜索将起带头作用,至少在我的测试中:

In [63]: a = np.array([np.nan if i % 10000 == 9999 else 3 for i in range(100000)])

In [70]: %paste
import numba

def naive(a):
        for i in range(len(a)):
                if np.isnan(a[i]):
                        return i

def short(a):
        return np.isnan(a).argmax()

@numba.jit
def naive_jit(a):
        for i in range(len(a)):
                if np.isnan(a[i]):
                        return i

@numba.jit
def short_jit(a):
        return np.isnan(a).argmax()
## -- End pasted text --

In [71]: %timeit naive(a)
100 loops, best of 3: 7.22 ms per loop

In [72]: %timeit short(a)
The slowest run took 4.59 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 37.7 µs per loop

In [73]: %timeit naive_jit(a)
The slowest run took 6821.16 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 6.79 µs per loop

In [74]: %timeit short_jit(a)
The slowest run took 395.51 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 144 µs per loop

编辑:正如@hpaulj在回答中指出的那样,numpy实际上附带了优化的短路搜索,其性能与上面的JITted搜索相当:

In [26]: %paste
def plain(a):
        return a.argmax()

@numba.jit
def plain_jit(a):
        return a.argmax()
## -- End pasted text --

In [35]: %timeit naive(a)
100 loops, best of 3: 7.13 ms per loop

In [36]: %timeit plain(a)
The slowest run took 4.37 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 7.04 µs per loop

In [37]: %timeit naive_jit(a)
100000 loops, best of 3: 6.91 µs per loop

In [38]: %timeit plain_jit(a)
10000 loops, best of 3: 125 µs per loop

答案 1 :(得分:8)

我将提名

a.argmax()

使用@fuglede's测试数组:

In [1]: a = np.array([np.nan if i % 10000 == 9999 else 3 for i in range(100000)])
In [2]: np.isnan(a).argmax()
Out[2]: 9999
In [3]: np.argmax(a)
Out[3]: 9999
In [4]: a.argmax()
Out[4]: 9999

In [5]: timeit a.argmax()
The slowest run took 29.94 ....
10000 loops, best of 3: 20.3 µs per loop

In [6]: timeit np.isnan(a).argmax()
The slowest run took 7.82 ...
1000 loops, best of 3: 462 µs per loop

我没有安装numba,所以可以比较一下。但是我相对于short的加速比大于@fuglede's 6x。

我在Py3中测试,它接受<np.nan,而Py2会引发运行时警告。但代码搜索表明这并不依赖于这种比较。

/numpy/core/src/multiarray/calculation.c PyArray_ArgMax使用轴(将感兴趣的内容移动到最后),并将操作委托给arg_func = PyArray_DESCR(ap)->f->argmax,这是一个取决于dtype的函数。

numpy/core/src/multiarray/arraytypes.c.src中,它看起来像BOOL_argmax短路,一遇到True就会立即返回。

for (; i < n; i++) {
    if (ip[i]) {
        *max_ind = i;
        return 0;
    }
}

并且@fname@_argmax也会在最大nan上发生短路。 np.nan是最大的&#39;也在argmin

#if @isfloat@
    if (@isnan@(mp)) {
        /* nan encountered; it's maximal */
        return 0;
    }
#endif

欢迎来自经验丰富的c编码员的评论,但在我看来,至少对于np.nan来说,普通argmax将会得到我们能够获得的速度。

在生成9999时使用a表示a.argmax时间取决于该值,与短路一致。

答案 2 :(得分:6)

以下是使用itertools.takewhile()的pythonic方法:

from itertools import takewhile
sum(1 for _ in takewhile(np.isfinite, a))

使用generator_expression_within _ next方法进行基准测试: 1

In [118]: a = np.repeat(a, 10000)

In [120]: %timeit next(i for i, j in enumerate(a) if np.isnan(j))
100 loops, best of 3: 12.4 ms per loop

In [121]: %timeit sum(1 for _ in takewhile(np.isfinite, a))
100 loops, best of 3: 11.5 ms per loop

但仍然(到目前为止)慢于numpy方法:

In [119]: %timeit np.isnan(a).argmax()
100000 loops, best of 3: 16.8 µs per loop

<子> 1.这种方法的问题是使用enumerate函数。它首先从numpy数组返回一个enumerate对象(这是一个像对象一样的迭代器),并且调用迭代器的生成器函数和next属性需要时间。

答案 3 :(得分:3)

在各种情况下寻找第一场比赛时,我们可以迭代并查找第一场比赛并在第一场比赛中退出,而不是去/处理整个阵列。所以,我们会采用Python's next function的方法,就像这样 -

next((i for i, val in enumerate(a) if np.isnan(val)))

样品运行 -

In [192]: a = np.array([3, 3, np.nan, 3, 3, np.nan])

In [193]: next((i for i, val in enumerate(a) if np.isnan(val)))
Out[193]: 2

In [194]: a[2] = 10

In [195]: next((i for i, val in enumerate(a) if np.isnan(val)))
Out[195]: 5