我正在尝试用numba优化一些代码。问题是简单的Cython优化(仅指定数据类型)比使用autojit快六倍,所以我不知道我是否做错了。
要优化的功能是:
from numba import autojit
@autojit(nopython=True)
def get_energy(system, i,j,m):
#system is an array, (i,j) some indices and m the size of the array
up=i-1; down=i+1; left=j-1; right=j+1
if up<0: total=system[m,j]
else: total=system[up,j]
if down>m: total+=system[0,j]
else: total+=system[down,j]
if left<0: total+=system[i,m]
else: total+=system[i,left]
if right>m: total+=system[i,0]
else: total+=system[i,right]
return 2*system[i,j]*total
简单的运行将是这样的:
import numpy as np
x=np.random.rand(50,50)
get_energy(x, 3, 5, 50)
我已经明白numba擅长循环,但可能无法很好地优化其他事情。无论如何,我希望与Cython有类似的性能,numba访问数组或条件语句的速度是否较慢?
Cython中的.pyx文件是:
import numpy as np
cimport cython
cimport numpy as np
def get_energy(np.ndarray[np.float64_t, ndim=2] system, int i,int j,unsigned int m):
cdef int up
cdef int down
cdef int left
cdef int right
cdef np.float64_t total
up=i-1; down=i+1; left=j-1; right=j+1
if up<0: total=system[m,j]
else: total=system[up,j]
if down>m: total+=system[0,j]
else: total+=system[down,j]
if left<0: total+=system[i,m]
else: total+=system[i,left]
if right>m: total+=system[i,0]
else: total+=system[i,right]
return 2*system[i,j]*total
如果我需要提供更多信息,请发表评论。