我的程序评估求解线性微分方程时的误差。它仅使用numpy数组。当我尝试将numba的jit装饰器用于定义的功能时,我只会遇到错误。你能帮我正确使用吗?
我的代码:
import numpy as np
from numba import jit
def rk4(t_prev, x_prev, derivs, dt):
k1 = dt * derivs(t_prev, x_prev)
k2 = dt * derivs(t_prev + 1/2*dt, x_prev + 1/2*k1)
k3 = dt * derivs(t_prev + 1/2*dt, x_prev + 1/2*k2)
k4 = dt * derivs(t_prev + dt, x_prev + k3)
x_next = x_prev + 1/6*k1 + 1/3*k2 + 1/3*k3 + 1/6*k4
return x_next
global k, x_0, v_0, t_0, t_f
k = 1
x_0 = 0
v_0 = np.sqrt(k)
t_0 = 0
t_f = 10
dtList = np.logspace(0, -5, 1000)
def derivs(t, X):
deriv = np.zeros([2])
deriv[0] = X[1]
deriv[1] = -k * X[0]
return deriv
def err(dt):
tList = np.arange(t_0, t_f + dt, dt)
N = tList.shape[0]
XList = np.zeros([N,2])
XList[0][0], XList[0][1] = x_0, v_0
for i in range(N-1):
XList[i+1] = rk4(tList[i], XList[i], derivs, dt)
error = np.abs(XList[-1][0] - np.sin(10))
return error
print(err(.001))
答案 0 :(得分:1)
以下对我有用:
import numpy as np
from numba import jit
@jit(nopython=True)
def rk4(t_prev, x_prev, derivs, dt):
k1 = dt * derivs(t_prev, x_prev)
k2 = dt * derivs(t_prev + 1/2*dt, x_prev + 1/2*k1)
k3 = dt * derivs(t_prev + 1/2*dt, x_prev + 1/2*k2)
k4 = dt * derivs(t_prev + dt, x_prev + k3)
x_next = x_prev + 1/6*k1 + 1/3*k2 + 1/3*k3 + 1/6*k4
return x_next
global k, x_0, v_0, t_0, t_f
k = 1
x_0 = 0
v_0 = np.sqrt(k)
t_0 = 0
t_f = 10
dtList = np.logspace(0, -5, 1000)
@jit(nopython=True)
def derivs(t, X):
deriv = np.zeros(2)
deriv[0] = X[1]
deriv[1] = -k * X[0]
return deriv
@jit(nopython=True)
def err(dt):
tList = np.arange(t_0, t_f + dt, dt)
N = tList.shape[0]
XList = np.zeros((N,2))
XList[0][0], XList[0][1] = x_0, v_0
for i in range(N-1):
XList[i+1] = rk4(tList[i], XList[i], derivs, dt)
error = np.abs(XList[-1][0] - np.sin(10))
return error
print(err(.001))
请注意,我对您的代码所做的仅有两项更改是替换了对np.zeros
的调用,这些调用将传入列表传递给2d情况下的tuple
或1d中的裸整数案件。请参见以下问题,以了解其原因: