我真的是numba
的新手,所以我真的不明白为什么会这样。我正在程序(Ising模型)的瓶颈功能上运行njit
,但确实会减慢速度。
我的功能是:
@nb.njit#(nb.types.int64[:](nb.types.int64, nb.types.float64[:], nb.types.int64[:], nb.types.int64[:]))
def heat_bath_mcs(size, Z, s, neighbors):
for step in range(size):
choice = int(size*np.random.random())
energy_variation = (s[neighbors[4*choice]]+s[neighbors[4*choice+1]]
+ s[neighbors[4*choice+2]]+s[neighbors[4*choice+3]])
if np.random.random() < 1.0/(1.0+Z[int(energy_variation*0.5)+2]):
s[choice] = +1
else:
s[choice] = -1
return s
size
是一个整数,s
和neighbors
是2个长度为size
的python列表,都具有整数值,而Z
是一个长度为4的列表浮点值。
在我试图推断类型的地方,但是它给出了这个错误(只使用第一个代码中#之后的内容):
TypeError:参数类型int64没有匹配的定义, 反映列表(float64),反映列表(int64),反映列表(int64)
如果在每次调用该函数之前打印出numba
类型,则:
print(nb.typeof(size),nb.typeof(Z),nb.typeof(s),nb.typeof(neighbors))
结果:
int64 reflected list(float64) reflected list(int64) reflected list(int64)
所以我的问题是,为什么会这样?我想我做错了什么,如何改善我的代码以加快速度?