如何在仅包含Numpy数组的程序中使用numba的jit?

时间:2019-10-07 08:06:15

标签: python numba

我的程序评估求解线性微分方程时的误差。它仅使用numpy数组。当我尝试将numba的jit装饰器用于定义的功能时,我只会遇到错误。你能帮我正确使用吗?

我的代码:

import numpy as np
from numba import jit

def rk4(t_prev, x_prev, derivs, dt):
    k1 = dt * derivs(t_prev, x_prev)
    k2 = dt * derivs(t_prev + 1/2*dt, x_prev + 1/2*k1)
    k3 = dt * derivs(t_prev + 1/2*dt, x_prev + 1/2*k2)
    k4 = dt * derivs(t_prev + dt, x_prev + k3)
    x_next = x_prev + 1/6*k1 + 1/3*k2 + 1/3*k3 + 1/6*k4
    return x_next

global k, x_0, v_0, t_0, t_f

k = 1

x_0 = 0
v_0 = np.sqrt(k)

t_0 = 0
t_f = 10

dtList = np.logspace(0, -5, 1000)


def derivs(t, X):
    deriv = np.zeros([2])
    deriv[0] = X[1]
    deriv[1] = -k * X[0]
    return deriv


def err(dt):
    tList = np.arange(t_0, t_f + dt, dt)
    N = tList.shape[0]
    XList = np.zeros([N,2])
    XList[0][0], XList[0][1] = x_0, v_0
    for i in range(N-1):
        XList[i+1] = rk4(tList[i], XList[i], derivs, dt)
    error = np.abs(XList[-1][0] - np.sin(10))
    return error

print(err(.001))

1 个答案:

答案 0 :(得分:1)

以下对我有用:

import numpy as np
from numba import jit

@jit(nopython=True)
def rk4(t_prev, x_prev, derivs, dt):
    k1 = dt * derivs(t_prev, x_prev)
    k2 = dt * derivs(t_prev + 1/2*dt, x_prev + 1/2*k1)
    k3 = dt * derivs(t_prev + 1/2*dt, x_prev + 1/2*k2)
    k4 = dt * derivs(t_prev + dt, x_prev + k3)
    x_next = x_prev + 1/6*k1 + 1/3*k2 + 1/3*k3 + 1/6*k4
    return x_next

global k, x_0, v_0, t_0, t_f

k = 1

x_0 = 0
v_0 = np.sqrt(k)

t_0 = 0
t_f = 10

dtList = np.logspace(0, -5, 1000)

@jit(nopython=True)
def derivs(t, X):
    deriv = np.zeros(2)
    deriv[0] = X[1]
    deriv[1] = -k * X[0]
    return deriv


@jit(nopython=True)
def err(dt):
    tList = np.arange(t_0, t_f + dt, dt)
    N = tList.shape[0]
    XList = np.zeros((N,2))
    XList[0][0], XList[0][1] = x_0, v_0
    for i in range(N-1):
        XList[i+1] = rk4(tList[i], XList[i], derivs, dt)
    error = np.abs(XList[-1][0] - np.sin(10))
    return error

print(err(.001))

请注意,我对您的代码所做的仅有两项更改是替换了对np.zeros的调用,这些调用将传入列表传递给2d情况下的tuple或1d中的裸整数案件。请参见以下问题,以了解其原因:

https://github.com/numba/numba/issues/3993