改善ODE求解的时间计算

时间:2018-03-14 15:49:03

标签: python performance

我需要求解一个 ODE 方程组,即用于心肌细胞建模的 FitzHugh–Nagumo 方程。我编写了一段代码,使用 Euler 方法求解这两个 ODE,代码如下:

import numpy as np
from numba import jitclass
from numba import int32, float64
import matplotlib.pyplot as plt
import time

# Field layout for the jitclass below: (attribute name, Numba type) pairs.
spec = [
    ('V_init', float64),
    ('a', float64),
    ('b', float64),
    ('g', float64),
    ('dt', float64),
    ('NbODEs', int32),
    ('dydx', float64[:]),
    ('y', float64[:]),
]

@jitclass(spec)
class FHNfunc:
    """FitzHugh-Nagumo two-variable model advanced with explicit Euler steps.

    ``y`` holds the two state variables and ``dydx`` their most recently
    evaluated time derivatives.
    """

    def __init__(self):
        # Model parameters.
        self.V_init = 0.04
        self.a = 0.25
        self.b = 0.001
        self.g = 0.003
        # Step size applied by Eul().
        self.dt = 0.01
        self.NbODEs = 2
        # State and derivative buffers both start at zero.
        self.dydx = np.zeros(self.NbODEs)
        self.y = np.zeros(self.NbODEs)

    def Eul(self):
        """Refresh ``dydx`` and advance ``y`` in place by one step of ``dt``."""
        self.deriv()
        self.y += self.dydx * self.dt

    def deriv(self):
        """Write the derivatives evaluated at the current ``y`` into ``dydx``."""
        v = self.y[0]
        w = self.y[1]
        self.dydx[0] = self.V_init - v * (self.a - v) * (1 - v) - w
        self.dydx[1] = self.b * v - self.g * w


# Step size shared by the time axis and the integrator.
dt = .001
FH = FHNfunc()
# FHNfunc initialises its own dt to .01; without this assignment every Euler
# step would advance the model ten times further than the spacing of `tp`.
FH.dt = dt
tp = np.linspace(0, 1000, num=int(1000 / dt))

# One recorded sample of each state variable per time point.
V = np.zeros(len(tp))
W = np.zeros(len(tp))

# Time only the integration loop (this includes the per-call overhead of
# invoking FH.Eul() from interpreted code).
t0 = time.time()
for idx in range(len(tp)):
    FH.Eul()
    V[idx] = FH.y[0]
    W[idx] = FH.y[1]

print(time.time() - t0)

plt.subplots()
plt.plot(tp, V)
plt.plot(tp, W)
plt.show()

我尝试使用 numba 的 jitclass 来缩短 FHN ODE 的求解时间,但效果不如我预期的那么好。对于这个例子,不使用 jitclass(即注释掉 @jitclass(spec, ))时耗时 11.44s,使用 jitclass 时耗时 6.14s。速度提高到两倍我并无怨言,但我原本期望更多。我知道可以把 for 循环放进类里,但我需要让它留在类外面。因此我想寻求进一步缩短这个例子计算时间的办法。

编辑:这次尝试用jit在类外部实现ODE函数:

__author__ = 'Maxime'
import numpy as np
from numba import jitclass, jit
from numba import int32, float64
import matplotlib.pyplot as plt
import time

# (attribute name, Numba type) pairs intended for @jitclass; the decorator on
# the class below is commented out in this variant.
spec = [
    ('V_init', float64),
    ('a', float64),
    ('b', float64),
    ('g', float64),
    ('dt', float64),
    ('NbODEs', int32),
    ('dydx', float64[:]),
    ('time', float64[:]),
    ('V', float64[:]),
    ('W', float64[:]),
    ('y', float64[:]),
]

# NOTE: @jitclass(spec) is left disabled in this variant; the derivative
# formulas live in the module-level functions fV and fW instead.
class FHNfunc:
    """FitzHugh-Nagumo model as a plain Python class stepped with Euler.

    ``y`` holds the two state variables and ``dydx`` their most recently
    evaluated time derivatives.
    """

    def __init__(self):
        # Model parameters.
        self.V_init = 0.04
        self.a = 0.25
        self.b = 0.001
        self.g = 0.003
        # Step size applied by Eul().
        self.dt = 0.001
        self.NbODEs = 2
        # State and derivative buffers both start at zero.
        self.dydx = np.zeros(self.NbODEs)
        self.y = np.zeros(self.NbODEs)

    def Eul(self):
        """Refresh ``dydx`` and advance ``y`` in place by one step of ``dt``."""
        self.deriv()
        self.y += self.dydx * self.dt

    def deriv(self):
        """Fill ``dydx`` by delegating each component to fV / fW."""
        v = self.y[0]
        w = self.y[1]
        self.dydx[0] = fV(self.V_init, v, w, self.a)
        self.dydx[1] = fW(v, w, self.b, self.g)


@jit(float64(float64, float64, float64, float64))
def fV(V_init, y0, y1, a):
    """Return the time derivative of the first state variable ``y0``."""
    cubic = y0 * (a - y0) * (1 - y0)
    return V_init - cubic - y1


@jit(float64(float64, float64, float64, float64))
def fW(y0, y1, b, g):
    """Return the time derivative of the second state variable ``y1``."""
    growth = b * y0
    decay = g * y1
    return growth - decay


FH = FHNfunc()
dt = .001
n_steps = int(1000 / dt)
tp = np.linspace(0, 1000, num=n_steps)

# One recorded sample of each state variable per time point.
V = np.zeros(n_steps)
W = np.zeros(n_steps)

# Time only the integration loop.
t0 = time.time()
for idx in range(n_steps):
    FH.Eul()
    V[idx] = FH.y[0]
    W[idx] = FH.y[1]
print(time.time() - t0)

plt.subplots()
plt.plot(tp, V)
plt.plot(tp, W)
plt.show()

但在这种情况下,耗时完全没有改善:11.4s。

为什么我不能把积分循环放在类中

当我有几个模型并且我想在它们之间进行耦合时,我需要在FHN实例之间传递变量。例如:

__author__ = 'Maxime'
import numpy as np
from numba import jitclass, jit, njit
from numba import int32, float64
import matplotlib.pyplot as plt
import time

# Field layout for the jitclass below: (attribute name, Numba type) pairs.
spec = [
    ('V_init', float64),
    ('a', float64),
    ('b', float64),
    ('g', float64),
    ('dt', float64),
    ('NbODEs', int32),
    ('dydx', float64[:]),
    ('time', float64[:]),
    ('V', float64[:]),
    ('W', float64[:]),
    ('y', float64[:]),
]

@jitclass(spec)
class FHNfunc:
    """FitzHugh-Nagumo two-variable model advanced with explicit Euler steps.

    ``V_init`` is a plain attribute so that it can be overwritten from
    outside between steps.
    """

    def __init__(self):
        # Model parameters.
        self.V_init = 0.04
        self.a = 0.25
        self.b = 0.001
        self.g = 0.003
        # Step size applied by Eul().
        self.dt = 0.001
        self.NbODEs = 2
        # State and derivative buffers both start at zero.
        self.dydx = np.zeros(self.NbODEs)
        self.y = np.zeros(self.NbODEs)

    def Eul(self):
        """Refresh ``dydx`` and advance ``y`` in place by one step of ``dt``."""
        self.deriv()
        self.y += self.dydx * self.dt

    def deriv(self):
        """Write the derivatives evaluated at the current ``y`` into ``dydx``."""
        v = self.y[0]
        w = self.y[1]
        self.dydx[0] = self.V_init - v * (self.a - v) * (1 - v) - w
        self.dydx[1] = self.b * v - self.g * w



# Two model instances; the second one receives its V_init from the first on
# every step of the loop below.
FH1 = FHNfunc()
FH2 = FHNfunc()
FH2.V_init = 0.
dt = .001
tp = np.linspace(0, 1000, num=int(1000 / dt))

# One recorded sample of each state variable, per instance, per time point.
V1 = np.zeros(len(tp))
V2 = np.zeros(len(tp))
W1 = np.zeros(len(tp))
W2 = np.zeros(len(tp))

# Time only the integration loop.
t0 = time.time()
for idx in range(len(tp)):
    FH1.Eul()
    # Coupling: hand FH1's V_init over to FH2 before FH2 takes its step.
    FH2.V_init = FH1.V_init
    FH2.Eul()
    V1[idx] = FH1.y[0]
    W1[idx] = FH1.y[1]
    V2[idx] = FH2.y[0]
    W2[idx] = FH2.y[1]

print(time.time() - t0)
# The original wrote `plt.figure` without parentheses, which only names the
# function and never creates the figure.
plt.figure()
plt.subplot(211)
plt.plot(tp, V1)
plt.plot(tp, W1)
plt.subplot(212)
plt.plot(tp, V2)
plt.plot(tp, W2)
plt.show()

在这种情况下,我不知道如何用 numpy 在实例之间传递变量。此外,在这个例子中所有实例都属于同一个类,但在我的完整模型中,有 8 个不同的类,分别表示系统中不同类型的模型。

@ max9111的答案

所以我用njit测试它,两个神经元连接在一起,效果很好:

__author__ = 'Maxime'
import numpy as np
from numba import jitclass, jit, njit
from numba import int32, float64
import matplotlib.pyplot as plt
import time

# Field layout for the jitclass below: (attribute name, Numba type) pairs.
spec = [
    ('V_init', float64),
    ('a', float64),
    ('b', float64),
    ('g', float64),
    ('dt', float64),
    ('NbODEs', int32),
    ('dydx', float64[:]),
    ('time', float64[:]),
    ('V', float64[:]),
    ('W', float64[:]),
    ('y', float64[:]),
]

@jitclass(spec)
class FHNfunc:
    """FitzHugh-Nagumo two-variable model advanced with explicit Euler steps.

    ``y`` holds the two state variables and ``dydx`` their most recently
    evaluated time derivatives.
    """

    def __init__(self):
        # Model parameters.
        self.V_init = 0.04
        self.a = 0.25
        self.b = 0.001
        self.g = 0.003
        # Step size applied by Eul().
        self.dt = 0.001
        self.NbODEs = 2
        # State and derivative buffers both start at zero.
        self.dydx = np.zeros(self.NbODEs)
        self.y = np.zeros(self.NbODEs)

    def Eul(self):
        """Refresh ``dydx`` and advance ``y`` in place by one step of ``dt``."""
        self.deriv()
        self.y += self.dydx * self.dt

    def deriv(self):
        """Write the derivatives evaluated at the current ``y`` into ``dydx``."""
        v = self.y[0]
        w = self.y[1]
        self.dydx[0] = self.V_init - v * (self.a - v) * (1 - v) - w
        self.dydx[1] = self.b * v - self.g * w


@njit(fastmath=True)
def solve2(FH1, FH2, tp):
    """Step two coupled models once per entry of ``tp`` and record their states.

    On every iteration FH1 is stepped first, its ``V_init`` is copied into
    FH2, and then FH2 is stepped. Both instances are mutated in place.

    Returns the tuple ``(V1, W1, V2, W2)``, where V* holds ``y[0]`` and W*
    holds ``y[1]`` of the corresponding instance after each step.
    """
    n = len(tp)
    V1 = np.zeros(n)
    V2 = np.zeros(n)
    W1 = np.zeros(n)
    W2 = np.zeros(n)

    for idx in range(n):
        FH1.Eul()
        FH2.V_init = FH1.V_init
        FH2.Eul()
        V1[idx] = FH1.y[0]
        W1[idx] = FH1.y[1]
        V2[idx] = FH2.y[0]
        W2[idx] = FH2.y[1]

    return V1, W1, V2, W2

if __name__ == "__main__":
    dt = .001
    n_steps = int(1000 / dt)

    # --- Variant 1: njit-compiled loop driving the jitclass instances ---
    FH1 = FHNfunc()
    FH2 = FHNfunc()
    FH2.V_init = 0.
    tp = np.linspace(0, 1000, num=n_steps)

    t0 = time.time()
    V1, W1, V2, W2 = solve2(FH1, FH2, tp)
    print(time.time() - t0)
    plt.figure()
    plt.subplot(211)
    plt.plot(tp, V1)
    plt.plot(tp, W1)
    plt.subplot(212)
    plt.plot(tp, V2)
    plt.plot(tp, W2)

    # --- Variant 2: the same loop executed by the interpreter (jitclass only) ---
    FH1 = FHNfunc()
    FH2 = FHNfunc()
    FH2.V_init = 0.
    tp = np.linspace(0, 1000, num=n_steps)

    # The output-array allocation is part of the timed region here.
    t0 = time.time()
    V1 = np.zeros(n_steps)
    V2 = np.zeros(n_steps)
    W1 = np.zeros(n_steps)
    W2 = np.zeros(n_steps)

    for idx in range(n_steps):
        FH1.Eul()
        FH2.V_init = FH1.V_init
        FH2.Eul()
        V1[idx] = FH1.y[0]
        W1[idx] = FH1.y[1]
        V2[idx] = FH2.y[0]
        W2[idx] = FH2.y[1]
    print(time.time() - t0)
    plt.figure()
    plt.subplot(211)
    plt.plot(tp, V1)
    plt.plot(tp, W1)
    plt.subplot(212)
    plt.plot(tp, V2)
    plt.plot(tp, W2)
    plt.show()

这样,在两个模型实例的情况下,启用全部优化(njit 和 jitclass)耗时 1.8s;只使用 jitclass 耗时 12.4s;完全不使用 numba 耗时 21.7s。也就是快了约 12 倍,相当不错。感谢 @max9111 提供的解决方案。

1 个答案:

答案 0 :(得分:1)

这完全取决于函数内联和 LLVM 优化

所有函数都非常简单(就计算时间而言)。因此 numba 唯一能做的就是内联这些函数,并缓存已编译的函数,以避免下次调用时的编译开销。

你的 jitclass 基准测试有一个主要问题:你是在未编译的代码中调用这些简单函数 1000000 次(也就是至少 1000000 次函数调用)。应该改成这样:

Example_1使用Jitclass

import numpy as np
from numba import jitclass,njit
from numba import int32, float64
import matplotlib.pyplot as plt
import time

# Field layout for the jitclass below: (attribute name, Numba type) pairs.
spec = [
    ('V_init', float64),
    ('a', float64),
    ('b', float64),
    ('g', float64),
    ('dt', float64),
    ('NbODEs', int32),
    ('dydx', float64[:]),
    ('y', float64[:]),
]

@jitclass(spec)
class FHNfunc:
    """FitzHugh-Nagumo two-variable model advanced with explicit Euler steps.

    ``y`` holds the two state variables and ``dydx`` their most recently
    evaluated time derivatives.
    """

    def __init__(self):
        # Model parameters.
        self.V_init = 0.04
        self.a = 0.25
        self.b = 0.001
        self.g = 0.003
        # Step size applied by Eul().
        self.dt = 0.001
        self.NbODEs = 2
        # State and derivative buffers both start at zero.
        self.dydx = np.zeros(self.NbODEs)
        self.y = np.zeros(self.NbODEs)

    def Eul(self):
        """Refresh ``dydx`` and advance ``y`` in place by one step of ``dt``."""
        self.deriv()
        self.y += self.dydx * self.dt

    def deriv(self):
        """Write the derivatives evaluated at the current ``y`` into ``dydx``."""
        v = self.y[0]
        w = self.y[1]
        self.dydx[0] = self.V_init - v * (self.a - v) * (1 - v) - w
        self.dydx[1] = self.b * v - self.g * w

@njit(fastmath=True)
def solve(FH, dt, tp):
    """Step ``FH`` once per entry of ``tp`` and record its state after each step.

    ``dt`` is accepted but not used here; the step size comes from ``FH.dt``.
    ``FH`` is mutated in place.

    Returns ``(V, W)``, holding ``FH.y[0]`` and ``FH.y[1]`` after every step.
    """
    n = len(tp)
    V = np.zeros(n)
    W = np.zeros(n)

    for idx in range(n):
        FH.Eul()
        V[idx] = FH.y[0]
        W[idx] = FH.y[1]

    return V, W

if __name__ == "__main__":
    FH = FHNfunc()
    dt = .001
    tp = np.linspace(0, 1000, num=int(1000 / dt))

    # The first call to solve() is what gets timed.
    t1 = time.time()
    V, W = solve(FH, dt, tp)
    print(time.time() - t1)

    plt.subplots()
    plt.plot(tp, V)
    plt.plot(tp, W)
    plt.show()

这给出了大约0.4s的运行时间。

Example_2和3

import numpy as np
import numba as nb
import matplotlib.pyplot as plt
import time

@nb.njit(fastmath=True,cache=True)
def Eul(V_init, y, a, g, dt, dydx, b=0.001):
    """Advance ``y`` in place by one explicit Euler step of size ``dt``.

    ``dydx`` is a caller-supplied scratch array that receives the derivatives.

    ``b`` used to be read from a module-level global that only existed once
    the ``__main__`` section had run. Numba treats globals as compile-time
    constants (and, with ``cache=True``, stores that value in the on-disk
    cache), so later changes to the global would be silently ignored. It is
    now an explicit trailing parameter whose default matches the value the
    script uses, so existing six-argument calls keep working.
    """
    deriv(V_init, y, a, b, g, dydx)
    y += dydx * dt

@nb.njit(fastmath=True,cache=True)
def deriv(V_init, y, a, b, g, dydx):
    """Fill ``dydx`` with the derivatives of both state variables at ``y``."""
    # Each component is delegated to its own small function.
    dydx[0] = fV(V_init, y[0], y[1], a)
    dydx[1] = fW(y[0], y[1], b, g)

@nb.njit(fastmath=True,cache=True)
def fV(V_init, y0, y1, a):
    """Return the time derivative of the first state variable ``y0``."""
    cubic = y0 * (a - y0) * (1 - y0)
    return V_init - cubic - y1

@nb.njit(fastmath=True,cache=True)
def fW(y0, y1, b, g):
    """Return the time derivative of the second state variable ``y1``."""
    growth = b * y0
    decay = g * y1
    return growth - decay

@nb.njit(fastmath=True,cache=True)
def solving_1(V_init, y, a, g, dt, tp):
    """Integrate by calling Eul() once per entry of ``tp``.

    ``y`` is updated in place. Returns ``(V, W)``, holding ``y[0]`` and
    ``y[1]`` after every step.
    """
    n = tp.shape[0]
    V = np.empty(n, dtype=y.dtype)
    W = np.empty(n, dtype=y.dtype)
    # Scratch buffer reused by every Eul() call.
    dydx = np.empty(2, dtype=np.float64)

    for idx in range(n):
        Eul(V_init, y, a, g, dt, dydx)
        V[idx] = y[0]
        W[idx] = y[1]

    return V, W

@nb.njit(fastmath=True,cache=True)
def solving_2(V_init, y, a, g, dt, tp, b=0.001):
    """Integrate with the derivative and Euler update written inline.

    ``y`` is updated in place. Returns ``(V, W)``, holding ``y[0]`` and
    ``y[1]`` after every step.

    ``b`` used to be read from a module-level global that only existed once
    the ``__main__`` section had run. Numba treats globals as compile-time
    constants (and, with ``cache=True``, stores that value in the on-disk
    cache), so later changes to the global would be silently ignored. It is
    now an explicit trailing parameter whose default matches the value the
    script uses, so existing six-argument calls keep working.
    """
    V = np.empty(tp.shape[0], dtype=y.dtype)
    W = np.empty(tp.shape[0], dtype=y.dtype)
    dydx = np.empty(2, dtype=y.dtype)

    for idx in range(tp.shape[0]):
        # Both derivatives are evaluated from the old state before y changes.
        dydx[0] = V_init - y[0] * (a - y[0]) * (1 - y[0]) - y[1]
        dydx[1] = b * y[0] - g * y[1]
        y[0] += dydx[0] * dt
        y[1] += dydx[1] * dt

        V[idx] = y[0]
        W[idx] = y[1]

    return V, W

if __name__ == "__main__":
    # Model parameters and step size.
    V_init = .04
    a = 0.25
    b = 0.001
    g = 0.003
    dt = .001
    tp = np.linspace(0, 1000, num=int(1000 / dt))

    # Initial state; the solver updates this array in place.
    y = np.zeros(2, dtype=np.float64)

    t1 = time.time()
    V, W = solving_2(V_init, y, a, g, dt, tp)
    print(time.time() - t1)

    plt.subplots()
    plt.plot(tp, V)
    plt.plot(tp, W)
    plt.show()

我在这里测试了两种变体:一种是把所有工作放在一个函数里(solving_2),另一种是拆分成多个函数(solving_1)。solving_1 耗时 0.17s,solving_2 耗时 0.06s。

jitclass 方法稍慢一些我并不意外(不支持缓存,而且是相当新的功能),但我没想到 solving_1 和 solving_2 两种写法之间会有 2 倍的性能差距;这可能是因为其中存在一些内存复制,而这些复制没有被优化掉。