是否可以将类方法引用传递给njit函数?

时间:2018-03-27 07:12:51

标签: python class reference numba

我试图改善一些代码的计算时间。所以我使用numba模块的import numpy as np from numba import jitclass, jit, njit from numba import int32, float64 import matplotlib.pyplot as plt import time spec = [('V_init' ,float64), ('a' ,float64), ('b' ,float64), ('g',float64), ('dt' ,float64), ('NbODEs',int32), ('dydx' ,float64[:]), ('time' ,float64[:]), ('V' ,float64[:]), ('W' ,float64[:]), ('y' ,float64[:]) ] @jitclass(spec, ) class FHNfunc: def __init__(self,): self.V_init = .04 self.a= 0.25 self.b=0.001 self.g = 0.003 self.dt = .01 self.NbODEs = 2 self.dydx =np.zeros(self.NbODEs ) self.y =np.zeros(self.NbODEs ) def Eul(self,): self.deriv() self.y += (self.dydx * self.dt) def deriv(self,): self.dydx[0]= self.V_init - self.y[0] *(self.a-(self.y[0]))*(1-(self.y[0]))-self.y[1] self.dydx[1]= self.b * self.y[0] - self.g * self.y[1] return @njit(fastmath=True) def solve1(FH1,FHEuler,tp): V = np.zeros(len(tp), ) W = np.zeros(len(tp), ) for idx, t in enumerate(tp): FHEuler V[idx] = FH1.y[0] W[idx] = FH1.y[1] return V,W if __name__ == "__main__": FH1 = FHNfunc() FHEuler = FH1.Eul dt = .01 tp = np.linspace(0, 1000, num = int((1000)/dt)) t0 = time.time() [V1,W1] = solve1(FH1,FHEuler,tp) print(time.time()- t0) plt.figure() plt.plot(tp,V1) plt.plot(tp,W1) plt.show() 装饰器来做到这一点。在这个例子中:

FHEuler = FH1.Eul

我想传递一个名为This error may have been caused by the following argument(s): - argument 1: cannot determine Numba type of <class 'method'> 的类方法的引用,但它崩溃并给我这个错误

{{1}}

那么可以将引用传递给njit函数吗?或者它是否存在变通方法?

1 个答案:

答案 0 :(得分:2)

Numba无法将该函数作为参数处理。另一种方法是在编译函数之前和之后使用内部函数来处理其他参数并返回内部函数,并在其中运行已编译的输入函数。请试试这个:

def solve1(FH1,FHEuler,tp):
    FHEuler_f = njit(FHEuler)
    @njit(fastmath=True)
    def inner(FH1_x, tp_x):
        V = np.zeros(len(tp_x), )
        W = np.zeros(len(tp_x), )
        for idx, t in enumerate(tp_x):
            FHEuler_f
            V[idx] = FH1_x.y[0]
            W[idx] = FH1_x.y[1]
        return V,W
    return inner(FH1, tp)

可能没有必要传递功能。这个看起来很有效

@njit(fastmath=True)
def solve1(FH1,tp):
    FHEuler = FH1.Eul
    V = np.zeros(len(tp), )
    W = np.zeros(len(tp), )

    for idx, t in enumerate(tp):
        FHEuler()
        V[idx] = FH1.y[0]
        W[idx] = FH1.y[1]
    return V,W