为什么不是“numpy.any”懒惰(短路)

时间:2017-08-19 12:29:52

标签: python performance numpy

我不明白为什么还没有进行如此基本的优化:

In [1]: %timeit np.ones(10**6).any()
100 loops, best of 3: 7.32 ms per loop

In [2]: %timeit np.ones(10**7).any()
10 loops, best of 3: 59.7 ms per loop

扫描整个阵列,即使结论是第一项的证据。

2 个答案:

答案 0 :(得分:6)

这是一个不固定的表现回归。 NumPy issue 3446.实际上 short-circuiting logic,但对ufunc.reduce机制的更改引入了围绕短路逻辑的不必要的基于块的外部循环,并且外部循环不知道如何短路。您可以看到对分块机制here的一些解释。

尽管如此,即使没有回归,短路效应也不会出现在您的测试中。首先,你要为数组创建计时,其次,我不认为他们曾为任何输入dtype输入短路逻辑,但布尔值。从讨论中可以看出,numpy.any背后的ufunc减少机制的细节会让这很困难。

讨论确实提出了令人惊讶的观点,即argminargmax方法似乎是布尔输入的短路。 A quick test显示自NumPy 1.12(不是最新版本,但当前在Ideone上的版本),x[x.argmax()]短路,并且它超出了x.any()x.max()对于1维布尔输入,无论输入是小还是大,无论短路是否得到回报。怪异!

答案 1 :(得分:5)

您需要为短路付出代价。您需要在代码中引入分支。

分支(例如if语句)的问题在于它们可能比使用替代操作(没有分支)慢,然后您还有分支预测,这可能包括显着的开销。

同样取决于编译器和处理器,无分支代码可以使用处理器矢量化。我不是这方面的专家,但也许某种SIMD或SSE?

我在这里使用numba,因为代码易于阅读且速度足够快,因此性能会因这些微小差异而改变:

import numba as nb
import numpy as np

@nb.njit
def any_sc(arr):
    for item in arr:
        if item:
            return True
    return False

@nb.njit
def any_not_sc(arr):
    res = False
    for item in arr:
        res |= item
    return res

arr = np.zeros(100000, dtype=bool)
assert any_sc(arr) == any_not_sc(arr)
%timeit any_sc(arr)
# 126 µs ± 7.12 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit any_not_sc(arr)
# 15.5 µs ± 962 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit arr.any()
# 31.1 µs ± 184 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

在没有分支的最坏情况下,它快了近10倍。但在最好的情况下,短路功能要快得多:

arr = np.zeros(100000, dtype=bool)
arr[0] = True
%timeit any_sc(arr)
# 1.97 µs ± 12.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit any_not_sc(arr)
# 15.1 µs ± 368 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit arr.any()
# 31.2 µs ± 2.23 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

所以应该优化哪种情况的问题:最好的情况?最糟糕的情况?平均情况(any的平均情况是什么?)

可能是NumPy开发人员希望优化最坏情况而不是最佳情况。或者他们只是不关心?或许他们只是想要"可预测"无论如何都要表现。

关于代码的注释:您可以衡量创建数组所需的时间以及执行any所需的时间。如果any发生短路,您的代码就不会注意到它!

%timeit np.ones(10**6)
# 9.12 ms ± 635 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit np.ones(10**7)
# 86.2 ms ± 5.15 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

对于支持你的问题的决定性时间,你应该使用它:

arr1 = np.ones(10**6)
arr2 = np.ones(10**7)
%timeit arr1.any()
# 4.04 ms ± 121 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit arr2.any()
# 39.8 ms ± 1.34 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)