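I have a function that I am trying to speed up with the @jit decorator from the Numba module. It has to be as fast as possible, because my main code calls this function millions of times. Here is my function: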
import numpy as np
from numba import jit, types
import Sweep  # My own module, works fine

# nrModes, nrMoments, CMx, dg, A, g, pd, fbr, Kinj, EAinj, Vmzm and ed are module-level globals
@jit(types.Tuple((types.complex128[:], types.float64[:]))(types.complex128[:], types.complex128[:], types.float64[:], types.float64[:], types.float64))
def MultiModeSL(Ef, Ef2, Nf, u, tijd):
    dEdt = np.zeros(nrModes, dtype=np.complex128)
    dNdt0 = np.zeros(nrMoments, dtype=np.complex128)
    Efcon = np.conjugate(Ef)
    for j in range(nrModes):
        for n in range(nrMoments):
            dEdt += 0.5 * CMx[:, j, n, 0] * dg * (1 + A*1j) * Nf[n] * Ef[j] * np.exp(1j * (Sweep.omega[j] - Sweep.omega) * tijd)
            for k in range(nrModes):
                if n == 0:
                    dNdt0 += g * CMx[j, k, 0, :] * Efcon[j] * Ef[k] * np.exp(1j * (Sweep.omega[k] - Sweep.omega[j]) * tijd)
                dNdt0 += dg * (1 + A*1j) * CMx[j, k, n, :] * Nf[n] * Efcon[j] * Ef[k] * np.exp(1j * (Sweep.omega[k] - Sweep.omega[j]) * tijd)
    dEdt += -0.5*(pd - g)*Ef + fbr*Ef2 + Kinj*EAinj*(1 + np.exp(1j*(u + Vmzm)))
    dNdt = Sweep.Jn - Nf*ed - dNdt0.real
    return dEdt, dNdt
The function works perfectly without the jit decorator. However, when I run it with @jit, I get this error:
numba.errors.LoweringError: Failed at object (object mode frontend)
Failed at object (object mode backend)
dEdt.1
File "Functions.py", line 82
[1] During: lowering "$237 = call $236(Ef, Ef2, Efcon, Nf, dEdt.1, dNdt0, tijd, u)" at /home/humblebee/MEGA/GUI RC/General_Formula/Functions.py (82)
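Line 82 corresponds to the for loop that uses j as its iterator.
Can you help me? Edit: Following Peter's suggestion and combining it with einsum, I was able to remove the loops. This made my function 3 times faster. Here is the new code: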
def MultiModeSL(Ef, Ef2, Nf, u, tijd):
    dEdt = np.zeros(nrModes, dtype=np.complex128)
    dNdt0 = np.zeros(nrMoments, dtype=np.complex128)
    Efcon = np.conjugate(Ef)
    dEdt = 0.5 * np.einsum("k, jkm, mk, kj -> j", dg*(1+A*1j), CMx[:, :, :, 0], (Ef[:] * Nf[:, None]), np.exp(1j * (OMEGA[:, None] - OMEGA) * tijd))
    dEdt += -0.5*(pd - g)*Ef + fbr*Ef2 + Kinj*EAinj*(1 + np.exp(1j*(u + Vmzm)))
    dNdt = -np.einsum("j, jkm, jk, kj ", g, CMx[:, :, :, 0], (Ef*Efcon[:, None]), np.exp(1j * (OMEGA[:, None] - OMEGA) * tijd))
    dNdt += -np.einsum("j, j, jknm, kjm, kj", dg, (1+A*1j), CMx, (Nf[:]*Efcon[:, None]*Ef[:, None, None]), np.exp(1j * (OMEGA[:, None] - OMEGA) * tijd))
    dNdt += JN - Nf*ed
    return dNdt
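One way to confirm that an einsum rewrite reproduces the loops is to compare the two on small random inputs. The sketch below does this for the dNdt0 accumulator from the original loop version. It assumes CMx has shape (nrModes, nrModes, nrMoments, nrMoments), as the indexing CMx[j, k, n, :] suggests, and that g, dg and A are scalars as in the loop version. The function names dNdt0_loop and dNdt0_einsum and all the data are hypothetical. It also builds the phase matrix once and shares it between both contractions.

```python
import numpy as np

def dNdt0_loop(CMx, Nf, Ef, omega, g, dg, A, t):
    """dNdt0 accumulated with the explicit j, n, k loops of the loop version."""
    nrModes, _, nrMoments, _ = CMx.shape
    Efcon = np.conjugate(Ef)
    out = np.zeros(nrMoments, dtype=np.complex128)
    for j in range(nrModes):
        for n in range(nrMoments):
            for k in range(nrModes):
                phase = np.exp(1j * (omega[k] - omega[j]) * t)
                if n == 0:
                    out += g * CMx[j, k, 0, :] * Efcon[j] * Ef[k] * phase
                out += dg * (1 + A * 1j) * CMx[j, k, n, :] * Nf[n] * Efcon[j] * Ef[k] * phase
    return out

def dNdt0_einsum(CMx, Nf, Ef, omega, g, dg, A, t):
    """Same quantity as two einsum contractions sharing one phase matrix."""
    # P[j, k] = conj(Ef[j]) * Ef[k] * exp(1j * (omega[k] - omega[j]) * t), computed once
    P = np.conjugate(Ef)[:, None] * Ef[None, :] * np.exp(1j * (omega[None, :] - omega[:, None]) * t)
    # n == 0 term: sum over j, k of P[j, k] * CMx[j, k, 0, m]
    term_g = g * np.einsum("jk,jkm->m", P, CMx[:, :, 0, :])
    # full term: sum over j, k, n of P[j, k] * CMx[j, k, n, m] * Nf[n]
    term_dg = dg * (1 + A * 1j) * np.einsum("jk,jknm,n->m", P, CMx, Nf)
    return term_g + term_dg

rng = np.random.default_rng(1)
nrModes, nrMoments = 3, 4  # small hypothetical sizes
CMx = rng.standard_normal((nrModes, nrModes, nrMoments, nrMoments))
Nf = rng.standard_normal(nrMoments)
Ef = rng.standard_normal(nrModes) + 1j * rng.standard_normal(nrModes)
omega = rng.standard_normal(nrModes)
g, dg, A, t = 0.4, 0.7, 2.0, 0.3  # hypothetical scalar parameters

loop_val = dNdt0_loop(CMx, Nf, Ef, omega, g, dg, A, t)
vec_val = dNdt0_einsum(CMx, Nf, Ef, omega, g, dg, A, t)
print(np.allclose(loop_val, vec_val))  # True
```

Computing the phase matrix once and reusing it avoids the three separate np.exp(...) evaluations per call in the code above. np.einsum also accepts optimize=True, which lets NumPy pick a contraction order for multi-operand expressions; whether that helps at these sizes would need to be timed.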
Can you suggest any further tricks to speed it up?
Answer 0 (score: 1)
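I can't see anything in your code that would prevent it from being vectorized. Vectorization can speed up this kind of Python code by a factor of roughly 100. I don't know how it compares with jit.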
For example, it looks like you could take dEdt out of the loop and compute it in a single step:
dEdt = 0.5 * (CMx[:, :, :, 0] * dg * (1+A*1j) * Nf[:] * Ef[:, None] * np.exp(1j * (Sweep.omega[None, :, None, None] - Sweep.omega) * tijd)).sum(axis=2).sum(axis=1) - 0.5*(pd-g)*Ef + fbr*Ef2 + Kinj*EAinj*(1 + np.exp(1j*(u+Vmzm)))
(Although I don't really know what the dimensions of your Sweep.omega are.)
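Here is a minimal sketch of the single-step idea for the gain part of dEdt. It assumes Sweep.omega is a 1-D array of length nrModes and CMx has shape (nrModes, nrModes, nrMoments, nrMoments). Under those assumptions, the phase factor should be a 2-D (i, j) array broadcast against the (i, j, n) slice CMx[:, :, :, 0], not a 4-D array. The function names and data are hypothetical, and the vectorized version is checked against the question's loop on random inputs.

```python
import numpy as np

def dEdt_gain_loop(CMx, Nf, Ef, omega, dg, A, t):
    """Gain term of dEdt written with the question's explicit loops."""
    nrModes, _, nrMoments, _ = CMx.shape
    out = np.zeros(nrModes, dtype=np.complex128)
    for j in range(nrModes):
        for n in range(nrMoments):
            out += 0.5 * CMx[:, j, n, 0] * dg * (1 + A * 1j) * Nf[n] * Ef[j] * np.exp(1j * (omega[j] - omega) * t)
    return out

def dEdt_gain_vectorized(CMx, Nf, Ef, omega, dg, A, t):
    """Same sum as one broadcast product over axes (i, j, n)."""
    # phase[i, j] = exp(1j * (omega[j] - omega[i]) * t)
    phase = np.exp(1j * (omega[None, :] - omega[:, None]) * t)
    terms = CMx[:, :, :, 0] * Nf[None, None, :] * Ef[None, :, None] * phase[:, :, None]
    # sum over j (axis 1) and n (axis 2), leaving one value per mode i
    return 0.5 * dg * (1 + A * 1j) * terms.sum(axis=(1, 2))

rng = np.random.default_rng(0)
nrModes, nrMoments = 3, 4  # small hypothetical sizes
CMx = rng.standard_normal((nrModes, nrModes, nrMoments, nrMoments))
Nf = rng.standard_normal(nrMoments)
Ef = rng.standard_normal(nrModes) + 1j * rng.standard_normal(nrModes)
omega = rng.standard_normal(nrModes)
dg, A, t = 0.7, 2.0, 0.3  # hypothetical scalar parameters

loop_val = dEdt_gain_loop(CMx, Nf, Ef, omega, dg, A, t)
vec_val = dEdt_gain_vectorized(CMx, Nf, Ef, omega, dg, A, t)
print(np.allclose(loop_val, vec_val))  # True
```

The remaining terms (-0.5*(pd-g)*Ef + fbr*Ef2 + ...) are already elementwise over the modes and can be added to the result unchanged.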
Answer 1 (score: 1)
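There may be other problems, but one of them is that referencing arrays through a module namespace does not currently seem to be supported (simple repro below). Try importing omega as a bare name.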
In [14]: %%file Sweep.py
...: import numpy as np
...: constant_val = 0.5
...: constant_arr = np.array([0, 1.5, 2.])
Overwriting Sweep.py
In [15]: Sweep.constant_val
Out[15]: 0.5
In [16]: Sweep.constant_arr
Out[16]: array([ 0. , 1.5, 2. ])
In [17]: @njit
    ...: def f(value):
    ...:     return value + Sweep.constant_val
    ...:
In [18]: f(100)
Out[18]: 100.5
In [19]: @njit
    ...: def f(value):
    ...:     return value + Sweep.constant_arr[0]
In [20]: f(100)
LoweringError: Failed at nopython (nopython mode backend)
'NoneType' object has no attribute 'module'
File "<ipython-input-19-0a259ade6b9e>", line 3
[1] During: lowering "$0.3 = getattr(value=$0.2, attr=constant_arr)" at <ipython-input-19-0a259ade6b9e> (3)
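Following the suggestion above, here is a minimal sketch of two ways to avoid the module-attribute lookup. In this sketch, constant_arr is a plain module-level name and stands in for what `from Sweep import constant_arr` (or `from Sweep import omega`) would produce. The second variant passes the array in as an argument. That is a different technique from the one the answer proposes, and it is included only for comparison. The names f_global and f_arg are hypothetical. The try/except fallback exists only so the sketch also runs where Numba is not installed. Numba's documentation describes global variables, including arrays, as compile-time constants, so changes made to constant_arr after f_global has been compiled are not seen by it.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback so the sketch also runs without Numba installed
    def njit(func):
        return func

# Plays the role of `from Sweep import constant_arr`: a plain module-level name.
constant_arr = np.array([0.0, 1.5, 2.0])

@njit
def f_global(value):
    # Reads the array through a bare global name, not a module attribute.
    return value + constant_arr[0]

@njit
def f_arg(value, arr):
    # Receives the array as an argument, so no global lookup is needed.
    return value + arr[1]

print(f_global(100.0))             # 100.0
print(f_arg(100.0, constant_arr))  # 101.5
```

Passing arrays as arguments also means updated values are picked up on every call, which matters if omega changes between calls.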