我需要求解一组ODE,即用于心脏细胞建模的FitzHugh-Nagumo方程。我编写了一段使用Euler方法求解这两个ODE的代码,如下所示:
import numpy as np
from numba import jitclass
from numba import int32, float64
import matplotlib.pyplot as plt
import time
# Field layout handed to numba's jitclass: five scalar model parameters,
# the number of equations, and two float64 work arrays.
spec = [
    ('V_init', float64),
    ('a', float64),
    ('b', float64),
    ('g', float64),
    ('dt', float64),
    ('NbODEs', int32),
    ('dydx', float64[:]),
    ('y', float64[:]),
]


@jitclass(spec)
class FHNfunc:
    """Two-variable FitzHugh-Nagumo model advanced with explicit Euler steps.

    ``y`` holds the two state variables and ``dydx`` their most recently
    computed time derivatives. ``V_init`` enters the first equation as a
    constant additive term.
    """

    def __init__(self):
        self.NbODEs = 2
        self.V_init = .04
        self.a = 0.25
        self.b = 0.001
        self.g = 0.003
        self.dt = .01
        self.y = np.zeros(self.NbODEs)
        self.dydx = np.zeros(self.NbODEs)

    def deriv(self):
        """Evaluate both right-hand sides at the current state into ``dydx``."""
        v = self.y[0]
        w = self.y[1]
        self.dydx[0] = self.V_init - v * (self.a - v) * (1 - v) - w
        self.dydx[1] = self.b * v - self.g * w

    def Eul(self):
        """Advance ``y`` in place by one forward-Euler step of size ``dt``."""
        self.deriv()
        increment = self.dydx * self.dt
        self.y += increment
# Benchmark: step one FHN model from interpreted Python, one Euler step per
# sample of the time grid, and record both state variables.
FH = FHNfunc()
dt = .001
# Keep the integrator step equal to the nominal spacing of the plotted time
# axis. The class default (0.01) would advance the model ten times further
# per sample than `tp` indicates, so the curves would be drawn against the
# wrong time values.
FH.dt = dt
tp = np.linspace(0, 1000, num=int(1000 / dt))
V = np.zeros(len(tp))
W = np.zeros(len(tp))
t0 = time.time()
for idx in range(len(tp)):
    FH.Eul()
    V[idx] = FH.y[0]
    W[idx] = FH.y[1]
print(time.time() - t0)
plt.subplots()
plt.plot(tp, V)
plt.plot(tp, W)
plt.show()
我尝试使用numba的jitclass
来缩短求解FHN ODE所需的时间,但效果不如我预期的那么明显。
对于这个例子,不使用jitclass(即注释掉@jitclass(spec, )
时)代码耗时11.44s,使用jitclass时耗时6.14s。速度提升到两倍我并没有什么可抱怨的,但我原本期待更多。我知道可以把for循环放进类里面,但我需要让它留在类的外面。
所以我想寻求一种方案,让这个例子的计算时间进一步缩短。
编辑:这次尝试用jit在类外部实现ODE函数:
__author__ = 'Maxime'
import numpy as np
from numba import jitclass, jit
from numba import int32, float64
import matplotlib.pyplot as plt
import time
# Field declarations in the format numba's jitclass expects: scalar float64
# parameters first, then the int32 equation count, then the float64 arrays.
# The resulting list has the same entries in the same order as a literal.
spec = (
    [(name, float64) for name in ('V_init', 'a', 'b', 'g', 'dt')]
    + [('NbODEs', int32)]
    + [(name, float64[:]) for name in ('dydx', 'time', 'V', 'W', 'y')]
)
# @jitclass(spec, )
class FHNfunc:
    """Plain-Python FitzHugh-Nagumo model advanced with explicit Euler steps.

    The jitclass decorator above is deliberately left commented out: in this
    variant only the right-hand-side helpers ``fV`` and ``fW`` are compiled.
    ``y`` holds the two state variables and ``dydx`` their derivatives.
    """

    def __init__(self):
        self.NbODEs = 2
        self.V_init = .04
        self.a = 0.25
        self.b = 0.001
        self.g = 0.003
        self.dt = .001
        self.y = np.zeros(self.NbODEs)
        self.dydx = np.zeros(self.NbODEs)

    def deriv(self):
        """Fill ``dydx`` by delegating to the module-level ``fV`` and ``fW``."""
        v, w = self.y[0], self.y[1]
        self.dydx[0] = fV(self.V_init, v, w, self.a)
        self.dydx[1] = fW(v, w, self.b, self.g)

    def Eul(self):
        """Advance ``y`` in place by one forward-Euler step of size ``dt``."""
        self.deriv()
        increment = self.dydx * self.dt
        self.y += increment
@jit(float64(float64, float64, float64, float64))
def fV(V_init, y0, y1, a):
    """Return the rate of change of the first state variable.

    Computes ``V_init - y0 * (a - y0) * (1 - y0) - y1``.
    """
    cubic = y0 * (a - y0) * (1 - y0)
    return V_init - cubic - y1
@jit(float64(float64, float64, float64, float64))
def fW(y0, y1, b, g):
    """Return the rate of change of the second state variable, ``b*y0 - g*y1``."""
    gain = b * y0
    decay = g * y1
    return gain - decay
# Benchmark: drive the plain-Python class (with jitted fV/fW helpers) from an
# interpreted loop, one Euler step per sample of the time grid.
FH = FHNfunc()
dt = .001
tp = np.linspace(0, 1000, num=int(1000 / dt))
n_steps = len(tp)
V = np.zeros(n_steps)
W = np.zeros(n_steps)
t0 = time.time()
for idx in range(n_steps):
    FH.Eul()
    # FH.y has exactly two entries, so it unpacks into the two traces.
    V[idx], W[idx] = FH.y
print(time.time() - t0)
plt.subplots()
plt.plot(tp, V)
plt.plot(tp, W)
plt.show()
但在这种情况下,耗时完全没有改善:仍为11.4s。
当我有多个模型并且想把它们耦合在一起时,我需要在FHN实例之间传递变量。例如:
__author__ = 'Maxime'
import numpy as np
from numba import jitclass, jit, njit
from numba import int32, float64
import matplotlib.pyplot as plt
import time
# Field layout handed to numba's jitclass. 'time', 'V' and 'W' are declared
# here but never assigned by the class below.
spec = [
    ('V_init', float64),
    ('a', float64),
    ('b', float64),
    ('g', float64),
    ('dt', float64),
    ('NbODEs', int32),
    ('dydx', float64[:]),
    ('time', float64[:]),
    ('V', float64[:]),
    ('W', float64[:]),
    ('y', float64[:]),
]


@jitclass(spec)
class FHNfunc:
    """Two-variable FitzHugh-Nagumo model advanced with explicit Euler steps.

    ``y`` holds the two state variables and ``dydx`` their most recently
    computed time derivatives. ``V_init`` enters the first equation as a
    constant additive term and may be overwritten from outside to couple
    instances.
    """

    def __init__(self):
        self.NbODEs = 2
        self.V_init = .04
        self.a = 0.25
        self.b = 0.001
        self.g = 0.003
        self.dt = .001
        self.y = np.zeros(self.NbODEs)
        self.dydx = np.zeros(self.NbODEs)

    def deriv(self):
        """Evaluate both right-hand sides at the current state into ``dydx``."""
        v = self.y[0]
        w = self.y[1]
        self.dydx[0] = self.V_init - v * (self.a - v) * (1 - v) - w
        self.dydx[1] = self.b * v - self.g * w

    def Eul(self):
        """Advance ``y`` in place by one forward-Euler step of size ``dt``."""
        self.deriv()
        increment = self.dydx * self.dt
        self.y += increment
# Benchmark: two FHN instances stepped in lock-step from interpreted Python.
# Before each step of FH2 its V_init is overwritten with FH1's V_init, which
# is how a value is passed from one instance to the other.
FH1 = FHNfunc()
FH2 = FHNfunc()
FH2.V_init = 0.
dt = .001
tp = np.linspace(0, 1000, num=int(1000 / dt))
V1 = np.zeros(len(tp))
V2 = np.zeros(len(tp))
W1 = np.zeros(len(tp))
W2 = np.zeros(len(tp))
t0 = time.time()
for idx in range(len(tp)):
    FH1.Eul()
    FH2.V_init = FH1.V_init
    FH2.Eul()
    V1[idx] = FH1.y[0]
    W1[idx] = FH1.y[1]
    V2[idx] = FH2.y[0]
    W2[idx] = FH2.y[1]
print(time.time() - t0)
# The original wrote `plt.figure` without parentheses, which only names the
# function and never calls it; call it so a figure is created explicitly.
plt.figure()
plt.subplot(211)
plt.plot(tp, V1)
plt.plot(tp, W1)
plt.subplot(212)
plt.plot(tp, V2)
plt.plot(tp, W2)
plt.show()
在这种情况下,我不知道如何使用numpy在实例之间传递变量。此外,在这个例子中所有实例都属于同一个类,但在我的完整模型中,有8个不同的类,分别表示系统中不同类型的模型。
于是我用njit进行了测试,把两个神经元连接在一起,效果很好:
__author__ = 'Maxime'
import numpy as np
from numba import jitclass, jit, njit
from numba import int32, float64
import matplotlib.pyplot as plt
import time
# Field layout handed to numba's jitclass. 'time', 'V' and 'W' are declared
# here but never assigned by the class below.
spec = [
    ('V_init', float64),
    ('a', float64),
    ('b', float64),
    ('g', float64),
    ('dt', float64),
    ('NbODEs', int32),
    ('dydx', float64[:]),
    ('time', float64[:]),
    ('V', float64[:]),
    ('W', float64[:]),
    ('y', float64[:]),
]


@jitclass(spec)
class FHNfunc:
    """Two-variable FitzHugh-Nagumo model advanced with explicit Euler steps.

    ``y`` holds the two state variables and ``dydx`` their most recently
    computed time derivatives. ``V_init`` enters the first equation as a
    constant additive term and may be overwritten from outside to couple
    instances.
    """

    def __init__(self):
        self.NbODEs = 2
        self.V_init = .04
        self.a = 0.25
        self.b = 0.001
        self.g = 0.003
        self.dt = .001
        self.y = np.zeros(self.NbODEs)
        self.dydx = np.zeros(self.NbODEs)

    def deriv(self):
        """Evaluate both right-hand sides at the current state into ``dydx``."""
        v = self.y[0]
        w = self.y[1]
        self.dydx[0] = self.V_init - v * (self.a - v) * (1 - v) - w
        self.dydx[1] = self.b * v - self.g * w

    def Eul(self):
        """Advance ``y`` in place by one forward-Euler step of size ``dt``."""
        self.deriv()
        increment = self.dydx * self.dt
        self.y += increment
@njit(fastmath=True)
def solve2(FH1, FH2, tp):
    """Step two coupled models once per entry of ``tp`` inside compiled code.

    Before each step of ``FH2`` its ``V_init`` is overwritten with the
    ``V_init`` of ``FH1``. Only the length of ``tp`` is used.

    Returns the four traces ``(V1, W1, V2, W2)``, each as long as ``tp``.
    """
    n_steps = len(tp)
    V1 = np.zeros(n_steps)
    V2 = np.zeros(n_steps)
    W1 = np.zeros(n_steps)
    W2 = np.zeros(n_steps)
    for idx in range(n_steps):
        FH1.Eul()
        FH2.V_init = FH1.V_init
        FH2.Eul()
        V1[idx] = FH1.y[0]
        W1[idx] = FH1.y[1]
        V2[idx] = FH2.y[0]
        W2[idx] = FH2.y[1]
    return V1, W1, V2, W2
if __name__ == "__main__":
    # Both benchmarks use the same time grid, so it is built once.
    dt = .001
    tp = np.linspace(0, 1000, num=int(1000 / dt))
    n_steps = len(tp)

    # --- Benchmark 1: njit-compiled loop driving jitclass instances ---
    FH1 = FHNfunc()
    FH2 = FHNfunc()
    FH2.V_init = 0.
    t0 = time.time()
    V1, W1, V2, W2 = solve2(FH1, FH2, tp)
    print(time.time() - t0)
    plt.figure()
    plt.subplot(211)
    plt.plot(tp, V1)
    plt.plot(tp, W1)
    plt.subplot(212)
    plt.plot(tp, V2)
    plt.plot(tp, W2)

    # --- Benchmark 2: jitclass instances driven by an interpreted loop ---
    FH1 = FHNfunc()
    FH2 = FHNfunc()
    FH2.V_init = 0.
    # The timer starts before the output arrays are allocated, so the
    # allocation is part of the measured time.
    t0 = time.time()
    V1 = np.zeros(n_steps)
    V2 = np.zeros(n_steps)
    W1 = np.zeros(n_steps)
    W2 = np.zeros(n_steps)
    for idx in range(n_steps):
        FH1.Eul()
        FH2.V_init = FH1.V_init
        FH2.Eul()
        V1[idx] = FH1.y[0]
        W1[idx] = FH1.y[1]
        V2[idx] = FH2.y[0]
        W2[idx] = FH2.y[1]
    print(time.time() - t0)
    plt.figure()
    plt.subplot(211)
    plt.plot(tp, V1)
    plt.plot(tp, W1)
    plt.subplot(212)
    plt.plot(tp, V2)
    plt.plot(tp, W2)
    plt.show()
这样,在启用全部优化(njit加jitclass)并使用两个模型实例的情况下,耗时为1.8秒;只用jitclass时为12.4s,完全不用numba时为21.7s。也就是快了12倍,相当不错。感谢@max9111提供的解决方案。
答案 0 :(得分:1)
关键全在于函数内联和LLVM优化
所有函数(就计算量而言)都非常简单。因此numba唯一能做的就是内联这些函数,并缓存已编译的函数,以避免下次调用时的编译开销。
你的jitclass基准测试有一个主要问题:你是在未编译的代码中调用这些简单函数1000000次(也就是说至少有1000000次函数调用)。应该写成这样:
Example_1使用Jitclass
import numpy as np
from numba import jitclass,njit
from numba import int32, float64
import matplotlib.pyplot as plt
import time
# Field layout handed to numba's jitclass: five scalar model parameters,
# the number of equations, and two float64 work arrays.
spec = [
    ('V_init', float64),
    ('a', float64),
    ('b', float64),
    ('g', float64),
    ('dt', float64),
    ('NbODEs', int32),
    ('dydx', float64[:]),
    ('y', float64[:]),
]


@jitclass(spec)
class FHNfunc:
    """Two-variable FitzHugh-Nagumo model advanced with explicit Euler steps.

    ``y`` holds the two state variables and ``dydx`` their most recently
    computed time derivatives. ``V_init`` enters the first equation as a
    constant additive term.
    """

    def __init__(self):
        self.NbODEs = 2
        self.V_init = .04
        self.a = 0.25
        self.b = 0.001
        self.g = 0.003
        self.dt = .001
        self.y = np.zeros(self.NbODEs)
        self.dydx = np.zeros(self.NbODEs)

    def deriv(self):
        """Evaluate both right-hand sides at the current state into ``dydx``."""
        v = self.y[0]
        w = self.y[1]
        self.dydx[0] = self.V_init - v * (self.a - v) * (1 - v) - w
        self.dydx[1] = self.b * v - self.g * w

    def Eul(self):
        """Advance ``y`` in place by one forward-Euler step of size ``dt``."""
        self.deriv()
        increment = self.dydx * self.dt
        self.y += increment
@njit(fastmath=True)
def solve(FH, dt, tp):
    """Step ``FH`` once per entry of ``tp`` inside compiled code.

    ``dt`` is accepted but not used here; the step size comes from ``FH``
    itself. Only the length of ``tp`` is used.

    Returns the traces ``(V, W)`` of ``FH.y[0]`` and ``FH.y[1]`` after each
    step.
    """
    n_steps = len(tp)
    V = np.zeros(n_steps)
    W = np.zeros(n_steps)
    for idx in range(n_steps):
        FH.Eul()
        V[idx] = FH.y[0]
        W[idx] = FH.y[1]
    return V, W
if __name__ == "__main__":
    # Time a single call of the compiled solver, then plot both traces.
    dt = .001
    tp = np.linspace(0, 1000, num=int(1000 / dt))
    FH = FHNfunc()
    t1 = time.time()
    V, W = solve(FH, dt, tp)
    print(time.time() - t1)
    plt.subplots()
    plt.plot(tp, V)
    plt.plot(tp, W)
    plt.show()
这给出了大约0.4s的运行时间。
Example_2和3
import numpy as np
import numba as nb
import matplotlib.pyplot as plt
import time
@nb.njit(fastmath=True, cache=True)
def Eul(V_init, y, a, g, dt, dydx):
    """Advance ``y`` in place by one forward-Euler step of size ``dt``.

    ``dydx`` is a caller-supplied scratch array that receives the
    derivatives computed by ``deriv``.

    NOTE(review): ``b`` is not a parameter of this function; it is looked up
    in module scope, so it must be defined there before the first call.
    Numba documents that globals are frozen as compile-time constants --
    confirm this is acceptable before changing ``b`` at runtime.
    """
    deriv(V_init, y, a, b, g, dydx)
    increment = dydx * dt
    y += increment
@nb.njit(fastmath=True, cache=True)
def deriv(V_init, y, a, b, g, dydx):
    """Write the derivatives of both state variables at ``y`` into ``dydx``."""
    v = y[0]
    w = y[1]
    dydx[0] = fV(V_init, v, w, a)
    dydx[1] = fW(v, w, b, g)
@nb.njit(fastmath=True, cache=True)
def fV(V_init, y0, y1, a):
    """Return the rate of change of the first state variable.

    Computes ``V_init - y0 * (a - y0) * (1 - y0) - y1``.
    """
    cubic = y0 * (a - y0) * (1 - y0)
    return V_init - cubic - y1
@nb.njit(fastmath=True, cache=True)
def fW(y0, y1, b, g):
    """Return the rate of change of the second state variable, ``b*y0 - g*y1``."""
    gain = b * y0
    decay = g * y1
    return gain - decay
@nb.njit(fastmath=True, cache=True)
def solving_1(V_init, y, a, g, dt, tp):
    """Integrate the model by calling ``Eul`` once per entry of ``tp``.

    ``y`` is updated in place and ends up holding the final state. Only the
    length of ``tp`` is used.

    Returns the traces ``(V, W)`` of ``y[0]`` and ``y[1]`` after each step,
    allocated with the dtype of ``y``.
    """
    n_steps = tp.shape[0]
    V = np.empty(n_steps, dtype=y.dtype)
    W = np.empty(n_steps, dtype=y.dtype)
    # Scratch buffer reused by every Euler step.
    dydx = np.empty(2, dtype=np.float64)
    for idx in range(n_steps):
        Eul(V_init, y, a, g, dt, dydx)
        V[idx] = y[0]
        W[idx] = y[1]
    return V, W
@nb.njit(fastmath=True, cache=True)
def solving_2(V_init, y, a, g, dt, tp):
    """Integrate the model with the whole Euler step written inline.

    ``y`` is updated in place and ends up holding the final state. Only the
    length of ``tp`` is used.

    NOTE(review): ``b`` is not a parameter of this function; it is looked up
    in module scope, so it must be defined there before the first call.
    Numba documents that globals are frozen as compile-time constants --
    confirm this is acceptable before changing ``b`` at runtime.

    Returns the traces ``(V, W)`` of ``y[0]`` and ``y[1]`` after each step,
    allocated with the dtype of ``y``.
    """
    n_steps = tp.shape[0]
    V = np.empty(n_steps, dtype=y.dtype)
    W = np.empty(n_steps, dtype=y.dtype)
    dydx = np.empty(2, dtype=y.dtype)
    for idx in range(n_steps):
        # Both derivatives are taken at the old state before y is touched.
        v = y[0]
        w = y[1]
        dydx[0] = V_init - v * (a - v) * (1 - v) - w
        dydx[1] = b * v - g * w
        y[0] += dydx[0] * dt
        y[1] += dydx[1] * dt
        V[idx] = y[0]
        W[idx] = y[1]
    return V, W
if __name__ == "__main__":
    # Model parameters. `b` is not passed to the solver below; the jitted
    # functions in this file read it from module scope, so it has to be
    # defined here before they are first called.
    V_init = .04
    a = 0.25
    b = 0.001
    g = 0.003
    dt = .001
    tp = np.linspace(0, 1000, num=int(1000 / dt))
    y = np.zeros(2, dtype=np.float64)
    t1 = time.time()
    V, W = solving_2(V_init, y, a, g, dt, tp)
    print(time.time() - t1)
    plt.subplots()
    plt.plot(tp, V)
    plt.plot(tp, W)
    plt.show()
我在这里测试了两种变体:一种把计算拆分成几个函数(solving_1),另一种把所有工作放在一个函数里(solving_2)。solving_1耗时0.17s,solving_2耗时0.06s。
对于jitclass方法稍慢一些,我并不感到惊讶(不支持缓存,而且是相当新的功能),但我没想到solving_1和solving_2之间会有2倍左右的性能差距;看起来即使只是多了一些内存复制,这些复制也没有被优化掉。