在下面的代码中,numba.jit生成7732.96 ...,而python给出-6351.97 ...(为简洁起见,省略了数字)。我该怎么做才能解决此问题?这是numba的错误还是我的编码错误?我在Spyder 3上使用了Python 3.7(anaconda)。
from numba import jit
import numpy as np
@jit(nopython=True)
def test(n):
sum = 0.0
arr = np.arange(2, n)
for x in np.sin(np.cos(arr ** 2)):
sum += x
return sum
a = test(100000000)
print(a)
答案 0 :(得分:1)
我无法重现您的错误:
from numba import jit
import numpy as np
def test(n):
sum = 0.0
arr = np.arange(2, n)
for x in np.sin(np.cos(arr ** 2)):
sum += x
return sum
testnb = jit(nopython=True)(test)
N = 100000000
print(test(N))
print(testnb(N))
# 7732.969676855288
# 7732.969676855337
我正在使用numba 0.45.1,python 3.7.3和numpy 1.16.4。我最初的猜测是,存在某种浮点问题,其中非{j1}形式的sum
是具有无限精度的python值,而在jitted代码中,sum
被键入为具体的float32或float64取决于您的系统。但是对于您的特定系统,我不确定发生了什么。