计算numba nopython函数中阶乘的最快方法

时间:2017-06-03 16:32:47

标签: python performance factorial numba

我有一个我想用numba编译的函数,但是我需要在该函数中计算一个阶乘。不幸的是numba不支持math.factorial

import math
import numba as nb

@nb.njit
def factorial1(x):
    return math.factorial(x)

factorial1(10)
# UntypedAttributeError: Failed at nopython (nopython frontend)

我看到它支持math.gamma(可用于计算阶乘),但与实际math.gamma函数相反,它不会返回表示“整数值”的浮点数:

@nb.njit
def factorial2(x):
    return math.gamma(x+1)

factorial2(10)
# 3628799.9999999995  <-- not exact

math.gamma(11)
# 3628800.0  <-- exact

math.factorial相比速度很慢:

%timeit factorial2(10)
# 1.12 µs ± 11.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit math.factorial(10)
# 321 ns ± 6.12 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

所以我决定定义自己的功能:

@nb.njit
def factorial3(x):
    n = 1
    for i in range(2, x+1):
        n *= i
    return n

factorial3(10)
# 3628800

%timeit factorial3(10)
# 821 ns ± 12.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

它仍然比math.factorial慢,但它比基于math.gamma的numba函数更快,而且值“精确”。

所以我正在寻找一种最快的方法来计算nopython numba函数中正整数(&lt; = 20;以避免溢出)的factorial

1 个答案:

答案 0 :(得分:1)

对于值&lt; = 20,python正在使用查找表,如评论中所建议的那样。 https://github.com/python/cpython/blob/3.6/Modules/mathmodule.c#L1452

LOOKUP_TABLE = np.array([
    1, 1, 2, 6, 24, 120, 720, 5040, 40320,
    362880, 3628800, 39916800, 479001600,
    6227020800, 87178291200, 1307674368000,
    20922789888000, 355687428096000, 6402373705728000,
    121645100408832000, 2432902008176640000], dtype='int64')

@nb.jit
def fast_factorial(n):
    if n > 20:
        raise ValueError
    return LOOKUP_TABLE[n]

从python调用它比python版本略慢,因为numba调度开销。

In [58]: %timeit math.factorial(10)
10000000 loops, best of 3: 79.4 ns per loop

In [59]: %timeit fast_factorial(10)
10000000 loops, best of 3: 173 ns per loop

但是在另一个numba函数中调用可以更快。

def loop_python():
    for i in range(10000):
        for n in range(21):
            math.factorial(n)

@nb.njit
def loop_numba():
    for i in range(10000):
        for n in range(21):
            fast_factorial(n)

In [65]: %timeit loop_python()
10 loops, best of 3: 36.7 ms per loop

In [66]: %timeit loop_numba()
10000000 loops, best of 3: 73.6 ns per loop