为什么numba会破坏这个功能?

时间:2021-07-22 15:55:53

标签: python numba

我最近听说了 numba,我今天想测试一下。 我制作了一个简单的程序,它可以获取一个数字的阶乘,同时对其进行计时。

import time
from numba import jit

@jit()
def fact(args):
    x = 1
    for i in range(2, args + 1):
        x *= i
    return x

st = time.time()
x = fact(100000)
print(x)
et = time.time()

elapsed_time = et - st

print("Time elapsed: ", elapsed_time)

出于某种原因,当存在 @jit 装饰器时,此代码仅输出“0”,但没有 numba 时,代码工作正常。

为什么会发生这种情况,我该如何解决?

2 个答案:

答案 0 :(得分:1)

我相信您可能会遇到使用本机类型的 numba 问题。具体来说,原生 int32 值(最有可能的罪魁祸首)的范围有限,而 Python int 将根据需要尽可能大。

如果我在循环中添加打印语句以显示 x 并运行直到 x 为 0,则输出为:

2
6
24
120
720
5040
40320
362880
3628800
39916800
479001600
1932053504
1278945280
2004310016
2004189184
-288522240
-898433024
109641728
-2102132736
-1195114496
-522715136
862453760
-775946240
2076180480
-1853882368
1484783616
-1375731712
-1241513984
1409286144
738197504
-2147483648
-2147483648
0

如您所见,这些都适合在 int32 中,并且在值中到处跳跃。

另一方面,没有numba的输出是:

2
6
24
120
720
5040
40320
362880
3628800
39916800
479001600
6227020800
87178291200
1307674368000
20922789888000
355687428096000
6402373705728000
121645100408832000
2432902008176640000
51090942171709440000
1124000727777607680000
25852016738884976640000
620448401733239439360000
15511210043330985984000000
403291461126605635584000000
10888869450418352160768000000
304888344611713860501504000000
8841761993739701954543616000000
265252859812191058636308480000000
8222838654177922817725562880000000
263130836933693530167218012160000000
8683317618811886495518194401280000000
295232799039604140847618609643520000000
10333147966386144929666651337523200000000

从 Python 输出中取出 6227020800 行,即

1 0111 0011 0010 1000 1100 1100 0000 0000

二进制。修剪掉不适合 int32 的额外位为您提供 1932053504,这正是您在 numba 输出中看到的。

答案 1 :(得分:1)

import numpy as np
import numba
from numba import jit
from numpy import prod
import time


def factorial(n):
    print( prod(range(1,n+1)))
        
factorial(1)
@jit(nopython=True)
def fact(args):
    x = 1
    for i in range(2, args + 1):
        x *= i
    return x
x = 10
st = time.time()
y = fact(x)
print(y)
et = time.time()

elapsed_time = et - st

print("Time elapsed: ", elapsed_time)

1

3628800

经过的时间:0.15069580078125

它适用于像 10 这样的小值 x = 10 对于大值,您必须使用现有的 numpy 函数

@jit()
def factorial1(n):
   return(np.math.factorial(n));

st = time.time()
x = factorial1(100)
print(x)
et = time.time()

elapsed_time = et - st

print("Time elapsed: ", elapsed_time)