我是Numba的新手,我正在努力加快一些已经证明过于笨重的计算。我在下面给出的示例比较了一个包含我的计算子集的函数,该函数使用函数的矢量化/ numpy和numba版本,后者也通过注释掉@autojit装饰器来测试为纯python。
我发现numba和numpy版本相对于纯蟒蛇提供了类似的速度提升,两者都提高了10倍的速度。 numpy版本实际上比我的numba函数略快,但由于这个计算的4D性质,当numpy函数中的数组的大小比这个玩具示例大得多时,我很快耗尽内存。
这种速度很快,但是从纯蟒蛇到numba时,我经常看到网上的速度提升> 100倍。
我想知道在nopython模式下移动到numba时是否有一般的预期速度增加。我还想知道我的numba-ized函数中是否有任何组件会限制进一步的速度增加。
import numpy as np
from timeit import default_timer as timer
from numba import autojit
import math
def vecRadCalcs(slope, skyz, solz, skya, sola):
nloc = len(slope)
ntime = len(solz)
[lenz, lena] = skyz.shape
asolz = np.tile(np.reshape(solz,[ntime,1,1,1]),[1,nloc,lenz,lena])
asola = np.tile(np.reshape(sola,[ntime,1,1,1]),[1,nloc,lenz,lena])
askyz = np.tile(np.reshape(skyz,[1,1,lenz,lena]),[ntime,nloc,1,1])
askya = np.tile(np.reshape(skya,[1,1,lenz,lena]),[ntime,nloc,1,1])
phi1 = np.cos(asolz)*np.cos(askyz)
phi2 = np.sin(asolz)*np.sin(askyz)*np.cos(askya- asola)
phi12 = phi1 + phi2
phi12[phi12> 1.0] = 1.0
phi = np.arccos(phi12)
return(phi)
@autojit
def RadCalcs(slope, skyz, solz, skya, sola, phi):
nloc = len(slope)
ntime = len(solz)
pop = 0.0
[lenz, lena] = skyz.shape
for iiT in range(ntime):
asolz = solz[iiT]
asola = sola[iiT]
for iL in range(nloc):
for iz in range(lenz):
for ia in range(lena):
askyz = skyz[iz,ia]
askya = skya[iz,ia]
phi1 = math.cos(asolz)*math.cos(askyz)
phi2 = math.sin(asolz)*math.sin(askyz)*math.cos(askya- asola)
phi12 = phi1 + phi2
if phi12 > 1.0:
phi12 = 1.0
phi[iz,ia] = math.acos(phi12)
pop = pop + 1
return(pop)
zenith_cells = 90
azim_cells = 360
nloc = 10 # nominallly ~ 700
ntim = 10 # nominallly ~ 200000
slope = np.random.rand(nloc) * 10.0
solz = np.random.rand(ntim) *np.pi/2.0
sola = np.random.rand(ntim) * 1.0*np.pi
base = np.ones([zenith_cells,azim_cells])
skya = np.deg2rad(np.cumsum(base,axis=1))
skyz = np.deg2rad(np.cumsum(base,axis=0)*90/zenith_cells)
phi = np.zeros(skyz.shape)
start = timer()
outcalc = RadCalcs(slope, skyz, solz, skya, sola, phi)
stop = timer()
outcalc2 = vecRadCalcs(slope, skyz, solz, skya, sola)
stopvec = timer()
print(outcalc)
print(stop-start)
print(stopvec-stop)
答案 0 :(得分:0)
在运行numba 0.31.0的机器上,Numba版本比矢量化解决方案快2倍。当计时numba函数时,你需要多次运行该函数,因为你第一次看到jitting代码的时间+运行时间。由于Numba将jitted代码缓存在内存中,后续运行将不包括调整函数时间的开销。
另外,请注意,您的功能并不是计算相同的东西 - 您要小心,在结果上使用np.allclose
之类的内容来比较相同的内容。