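I use numba's jit decorator a lot, and I recently realized that new features have been added to numba, in particular the parallel option and the stencil decorator.
Stencils are great for writing cleaner code, but after a few tests they seem to be only cosmetic and not any faster. Here is the example code: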
import numba

# Plain jit, types inferred at the first call
@numba.njit
def nb_jit(A, out):
    for i in range(1, A.shape[0]-1):
        out[i] = 0.5*(A[i+1] - A[i-1])
    return out

# jit with an explicit signature
@numba.njit(numba.float64[:](numba.float64[:], numba.float64[:]))
def nb_jit_typed(A, out):
    for i in range(1, A.shape[0]-1):
        out[i] = 0.5*(A[i+1] - A[i-1])
    return out

# jit with the parallel option and prange
@numba.njit(parallel=True)
def nb_jit_paral(A, out):
    for i in numba.prange(1, A.shape[0]-1):
        out[i] = 0.5*(A[i+1] - A[i-1])
    return out

# Stencil kernel using relative indexing
@numba.stencil
def s2(A):
    return 0.5*(A[1] - A[-1])

@numba.njit
def nb_stencil(A):
    return s2(A)

@numba.njit(parallel=True)
def nb_stencil_paral(A):
    return s2(A)
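For reference, here is a minimal NumPy-only sketch of what the two families of functions above compute. The names central_diff_inplace and central_diff_zero_border are hypothetical and not part of numba. The sketch assumes the stencil variant leaves zeros at the borders (numba's default constant border handling, as far as I understand it), whereas the loop variants leave out[0] and out[-1] untouched.

```python
import numpy as np

def central_diff_inplace(A, out):
    """Loop-style variant: only the interior points of `out` are written."""
    out[1:-1] = 0.5 * (A[2:] - A[:-2])
    return out

def central_diff_zero_border(A):
    """Stencil-style variant: new array, borders set to 0 (assumed default)."""
    out = np.zeros_like(A)
    out[1:-1] = 0.5 * (A[2:] - A[:-2])
    return out

A = np.array([1.0, 2.0, 5.0, 10.0, 17.0])
loop_like = central_diff_inplace(A, A.copy())   # borders keep the copied values
stencil_like = central_diff_zero_border(A)      # borders are zero

print(loop_like)
print(stencil_like)
```

The interior values agree, so the variants are comparable for timing purposes, but the stencil versions also allocate a new output array on every call while the loop versions write into a preallocated one.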
I tested these functions with the following array:
import numpy as np
arr = np.random.rand(100000)
res = arr.copy()
This gave me the following execution times (of course, I ran each function once before timing it with timeit!):
----------------------------------------------------
| Call                             | Time per loop |
----------------------------------------------------
| %timeit nb_jit(arr, res)         | 36 µs         |
| %timeit nb_jit_typed(arr, res)   | 68 µs         |
| %timeit nb_jit_paral(arr, res)   | 151 µs        |
| %timeit nb_stencil(arr)          | 59 µs         |
| %timeit nb_stencil_paral(arr)    | 241 µs        |
----------------------------------------------------
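Outside IPython, a measurement of this kind can be sketched with the standard timeit module. This is only an illustration: best_time_per_call is a hypothetical helper, and central_diff is a pure-Python stand-in for the jitted functions so that the snippet runs without numba. The warm-up call is what keeps compilation time out of the measurement for a jitted function.

```python
import timeit

def best_time_per_call(fn, *args, number=1000, repeat=5):
    """Return the best observed time per call, in seconds."""
    fn(*args)  # warm-up call: a jitted function would compile here
    totals = timeit.repeat(lambda: fn(*args), number=number, repeat=repeat)
    return min(totals) / number

def central_diff(A, out):
    # Pure-Python stand-in with the same loop as nb_jit (assumption: lists, not arrays)
    for i in range(1, len(A) - 1):
        out[i] = 0.5 * (A[i + 1] - A[i - 1])
    return out

data = [float(i * i) for i in range(1000)]
result = list(data)
seconds = best_time_per_call(central_diff, data, result, number=20, repeat=3)
print(f"{seconds * 1e6:.1f} us per call")
```

With the functions from the question, the call would look like best_time_per_call(nb_jit, arr, res).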
So I am wondering:
Why is nb_jit_typed slower than nb_jit? As far as I remember, the last time I tested this it was the other way around.
Why is nb_jit_paral so slow?
Note:
>>> import numba
>>> numba.__version__
'0.37.0'
>>> import multiprocessing
>>> multiprocessing.cpu_count()
4
Edit
Using time.time() (without any GUI), with an array of shape (1000000,) and each function called 10000 times:
jit           | 16.37 s
jit typed     | 17.22 s
jit parallel  | 18.45 s
stencil       | 21.95 s
stencil paral | 24.48 s
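The loop behind these totals can be sketched roughly as follows. total_wall_time and stand_in are hypothetical names, and the stand-in only records its calls so the snippet runs without numba; the real measurement would pass one of the jitted functions and the array instead.

```python
import time

def total_wall_time(fn, args, n_calls):
    """Total wall-clock seconds for n_calls calls, after one untimed call."""
    fn(*args)  # untimed first call (would absorb JIT compilation)
    start = time.time()
    for _ in range(n_calls):
        fn(*args)
    return time.time() - start

calls = []

def stand_in(x):
    # Hypothetical placeholder for a jitted function; it just records the call
    calls.append(x)
    return x

elapsed = total_wall_time(stand_in, (1.0,), n_calls=100)
print(f"{elapsed:.6f} s for 100 calls")
```

Note that a total measured this way also includes the Python loop and call dispatch overhead for every iteration.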