我有一个我想用numba
编译的函数,但是我需要在该函数中计算一个阶乘。不幸的是numba
不支持math.factorial
:
import math
import numba as nb
@nb.njit
def factorial1(x):
return math.factorial(x)
factorial1(10)
# UntypedAttributeError: Failed at nopython (nopython frontend)
我看到它支持math.gamma
(可用于计算阶乘),但与实际math.gamma
函数相反,它不会返回表示“整数值”的浮点数:
@nb.njit
def factorial2(x):
return math.gamma(x+1)
factorial2(10)
# 3628799.9999999995 <-- not exact
math.gamma(11)
# 3628800.0 <-- exact
与math.factorial
相比速度很慢:
%timeit factorial2(10)
# 1.12 µs ± 11.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit math.factorial(10)
# 321 ns ± 6.12 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
所以我决定定义自己的功能:
@nb.njit
def factorial3(x):
n = 1
for i in range(2, x+1):
n *= i
return n
factorial3(10)
# 3628800
%timeit factorial3(10)
# 821 ns ± 12.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
它仍然比math.factorial
慢,但它比基于math.gamma
的numba函数更快,而且值“精确”。
所以我正在寻找一种最快的方法来计算nopython numba函数中正整数(&lt; = 20;以避免溢出)的factorial
。
答案 0 :(得分:1)
对于值&lt; = 20,python正在使用查找表,如评论中所建议的那样。 https://github.com/python/cpython/blob/3.6/Modules/mathmodule.c#L1452
LOOKUP_TABLE = np.array([
1, 1, 2, 6, 24, 120, 720, 5040, 40320,
362880, 3628800, 39916800, 479001600,
6227020800, 87178291200, 1307674368000,
20922789888000, 355687428096000, 6402373705728000,
121645100408832000, 2432902008176640000], dtype='int64')
@nb.jit
def fast_factorial(n):
if n > 20:
raise ValueError
return LOOKUP_TABLE[n]
从python调用它比python版本略慢,因为numba调度开销。
In [58]: %timeit math.factorial(10)
10000000 loops, best of 3: 79.4 ns per loop
In [59]: %timeit fast_factorial(10)
10000000 loops, best of 3: 173 ns per loop
但是在另一个numba函数中调用可以更快。
def loop_python():
for i in range(10000):
for n in range(21):
math.factorial(n)
@nb.njit
def loop_numba():
for i in range(10000):
for n in range(21):
fast_factorial(n)
In [65]: %timeit loop_python()
10 loops, best of 3: 36.7 ms per loop
In [66]: %timeit loop_numba()
10000000 loops, best of 3: 73.6 ns per loop