如何在路径关闭时让SciPy.integrate.odeint停止?

时间:2015-10-12 02:57:52

标签: python scipy odeint

下面的脚本在封闭路径周围集成磁场线,并在使用Python中的Runge-Kutta RK4在某个容差范围内恢复原始值时停止。我想使用SciPy.integrate.odeint,但我看不出当路径大概关闭时我怎么能告诉它停止。

当然odeint可能比在Python中集成要快得多,我可以让它盲目地绕过并在结果中寻找结束,但在将来我会做更大的问题。< / p>

有没有办法可以实现&#34; 确定已经足够接近 - 你现在可以停下来了!&#34;进入odeint的方法?或者我应该整合一段时间,检查,整合更多,检查......

discussion似乎相关,并且似乎表明&#34;您无法从SciPy&#34;可能就是答案。

注意:我通常使用RK45(Runge-Kutta-Fehlberg),它在给定的行程尺寸下更准确,以加快速度,但我在这里保持简单。它还可以实现可变步长。

更新:但有时我需要固定的步长。我发现Scipy.integrate.ode确实提供了一种测试/停止方法ode.solout(t, y),但似乎没有能力在t的固定点进行评估。 odeint允许在t的固定点进行评估,但似乎没有测试/停止方法。

enter image description here

def rk4Bds_stops(x, h, n, F, fclose=0.1):

    h_over_two, h_over_six = h/2.0, h/6.0

    watching = False
    distance_max = 0.0
    distance_old = -1.0

    i = 0

    while i < n and not (watching and greater):

        k1 = F( x[i]                )
        k2 = F( x[i] + k1*h_over_two)
        k3 = F( x[i] + k2*h_over_two)
        k4 = F( x[i] + k3*h         )

        x[i+1] = x[i] + h_over_six * (k1 + 2.*(k2 + k3) + k4)

        distance = np.sqrt(((x[i+1] - x[0])**2).sum())
        distance_max = max(distance, distance_max)
        getting_closer = distance < distance_old

        if getting_closer and distance < fclose*distance_max: 
            watching = True

        greater = distance > distance_old
        distance_old = distance

        i += 1

    return i


def get_BrBztanVec(rz):

    Brz = np.zeros(2)

    B_zero = 0.5 * i * mu0 / a
    zz    = rz[1] - h
    alpha = rz[0] / a
    beta  = zz / a
    gamma = zz / rz[0]

    Q = ((1.0 + alpha)**2 + beta**2)
    k = np.sqrt(4. * alpha / Q)

    C1 =    1.0 / (pi * np.sqrt(Q))
    C2 = gamma  / (pi * np.sqrt(Q))
    C3 = (1.0 - alpha**2 - beta**2) / (Q - 4.0*alpha)
    C4 = (1.0 + alpha**2 + beta**2) / (Q - 4.0*alpha)

    E, K = spe.ellipe(k**2), spe.ellipk(k**2)

    Brz[0] += B_zero * C2 * (C4*E - K) 
    Brz[1] += B_zero * C1 * (C3*E + K)

    Bmag = np.sqrt((Brz**2).sum())

    return Brz/Bmag


import numpy as np
import matplotlib.pyplot as plt
import scipy.special as spe
from scipy.integrate import odeint as ODEint

pi = np.pi
mu0 = 4.0 * pi * 1.0E-07

i = 1.0 # amperes
a = 1.0 # meters
h = 0.0 # meters

ds = 0.04  # step distance (meters)

r_list, z_list, n_list = [], [], []
dr_list, dz_list = [], []

r_try = np.linspace(0.15, 0.95, 17)

x = np.zeros((1000, 2))

nsteps = 500

for rt in r_try:

    x[:] = np.nan

    x[0] = np.array([rt, 0.0])

    n = rk4Bds_stops(x, ds, nsteps, get_BrBztanVec)

    n_list.append(n)

    r, z = x[:n+1].T.copy()  # make a copy is necessary

    dr, dz = r[1:] - r[:-1], z[1:] - z[:-1]
    r_list.append(r)
    z_list.append(z)
    dr_list.append(dr)
    dz_list.append(dz)

plt.figure(figsize=[14, 8])
fs = 20

plt.subplot(2,3,1)
for r in r_list:
    plt.plot(r)
plt.title("r", fontsize=fs)

plt.subplot(2,3,2)
for z in z_list:
    plt.plot(z)
plt.title("z", fontsize=fs)

plt.subplot(2,3,3)
for r, z in zip(r_list, z_list):
    plt.plot(r, z)
plt.title("r, z", fontsize=fs)

plt.subplot(2,3,4)
for dr, dz in zip(dr_list, dz_list):
    plt.plot(dr, dz)
plt.title("dr, dz", fontsize=fs)

plt.subplot(2, 3, 5)
plt.plot(n_list)
plt.title("n", fontsize=fs)

plt.show()

1 个答案:

答案 0 :(得分:1)

您需要的是'事件处理'。 scipy.integrate.odeint无法做到这一点。但是你可以使用日程(见https://pypi.python.org/pypi/python-sundials/0.5),它可以进行事件处理。

将速度保持为优先级的另一个选择是简单地在cython中编写rkf代码。我有一个实现,应该很容易改变,以便在一些标准后停止:

cythoncode.pyx

import numpy as np
cimport numpy as np
import cython
#cython: boundscheck=False
#cython: wraparound=False

cdef double a2  =   2.500000000000000e-01  #  1/4
cdef double a3  =   3.750000000000000e-01  #  3/8
cdef double a4  =   9.230769230769231e-01  #  12/13
cdef double a5  =   1.000000000000000e+00  #  1
cdef double a6  =   5.000000000000000e-01  #  1/2

cdef double b21 =   2.500000000000000e-01  #  1/4
cdef double b31 =   9.375000000000000e-02  #  3/32
cdef double b32 =   2.812500000000000e-01  #  9/32
cdef double b41 =   8.793809740555303e-01  #  1932/2197
cdef double b42 =  -3.277196176604461e+00  # -7200/2197
cdef double b43 =   3.320892125625853e+00  #  7296/2197
cdef double b51 =   2.032407407407407e+00  #  439/216
cdef double b52 =  -8.000000000000000e+00  # -8
cdef double b53 =   7.173489278752436e+00  #  3680/513
cdef double b54 =  -2.058966861598441e-01  # -845/4104
cdef double b61 =  -2.962962962962963e-01  # -8/27
cdef double b62 =   2.000000000000000e+00  #  2
cdef double b63 =  -1.381676413255361e+00  # -3544/2565
cdef double b64 =   4.529727095516569e-01  #  1859/4104
cdef double b65 =  -2.750000000000000e-01  # -11/40

cdef double r1  =   2.777777777777778e-03  #  1/360
cdef double r3  =  -2.994152046783626e-02  # -128/4275
cdef double r4  =  -2.919989367357789e-02  # -2197/75240
cdef double r5  =   2.000000000000000e-02  #  1/50
cdef double r6  =   3.636363636363636e-02  #  2/55

cdef double c1  =   1.157407407407407e-01  #  25/216
cdef double c3  =   5.489278752436647e-01  #  1408/2565
cdef double c4  =   5.353313840155945e-01  #  2197/4104
cdef double c5  =  -2.000000000000000e-01  # -1/5

cdef class cyfunc:
    cdef double dy[2]

    cdef double* f(self,  double* y):    
        return self.dy
    def __cinit__(self):
        pass

@cython.cdivision(True)
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef rkf(cyfunc f, np.ndarray[double, ndim=1] times, 
          np.ndarray[double, ndim=1] x0, 
          double tol=1e-7, double dt_max=-1.0, double dt_min=1e-8):

    # Initialize
    cdef double t = times[0]
    cdef int times_index = 1
    cdef int add = 0
    cdef double end_time = times[len(times) - 1]
    cdef np.ndarray[double, ndim=1] res = np.empty_like(times)
    res[0] = x0[1] # Only storing second variable
    cdef double x[2]
    x[:] = x0

    cdef double k1[2]
    cdef double k2[2]
    cdef double k3[2]
    cdef double k4[2]
    cdef double k5[2]
    cdef double k6[2]
    cdef double r[2]

    while abs(t - times[times_index]) < tol: # if t = 0 multiple times
        res[times_index] = res[0]
        t = times[times_index]
        times_index += 1

    if dt_max == -1.0:
        dt_max = 5. * (times[times_index] - times[0])
    cdef double dt = dt_max/10.0
    cdef double tolh = tol*dt

    while t < end_time:
        # If possible, step to next time to save
        if t + dt >= times[times_index]:
            dt = times[times_index] - t;
            add = 1

        # Calculate Runga Kutta variables
        k1 = f.f(x)
        k1[0] *= dt; k1[1] *= dt; 
        r[0] = x[0] + b21 * k1[0]
        r[1] = x[1] + b21 * k1[1]

        k2 = f.f(r)
        k2[0] *= dt; k2[1] *= dt; 
        r[0] = x[0] + b31 * k1[0] + b32 * k2[0]
        r[1] = x[1] + b31 * k1[1] + b32 * k2[1]

        k3 = f.f(r)
        k3[0] *= dt; k3[1] *= dt; 
        r[0] = x[0] + b41 * k1[0] + b42 * k2[0] + b43 * k3[0]
        r[1] = x[1] + b41 * k1[1] + b42 * k2[1] + b43 * k3[1]

        k4 = f.f(r)
        k4[0] *= dt; k4[1] *= dt; 
        r[0] = x[0] + b51 * k1[0] + b52 * k2[0] + b53 * k3[0] + b54 * k4[0]
        r[1] = x[1] + b51 * k1[1] + b52 * k2[1] + b53 * k3[1] + b54 * k4[1]

        k5 = f.f(r)
        k5[0] *= dt; k5[1] *= dt;
        r[0] = x[0] + b61 * k1[0] + b62 * k2[0] + b63 * k3[0] + b64 * k4[0] + b65 * k5[0]
        r[1] = x[1] + b61 * k1[1] + b62 * k2[1] + b63 * k3[1] + b64 * k4[1] + b65 * k5[1]

        k6 = f.f(r)
        k6[0] *= dt; k6[1] *= dt;

        # Find largest error
        r[0] = abs(r1 * k1[0] + r3 * k3[0] + r4 * k4[0] + r5 * k5[0] + r6 * k6[0])
        r[1] = abs(r1 * k1[1] + r3 * k3[1] + r4 * k4[1] + r5 * k5[1] + r6 * k6[1])
        if r[1] > r[0]:
            r[0] = r[1]

        # If error is smaller than tolerance, take step
        tolh = tol*dt
        if r[0] <= tolh:
            t = t + dt
            x[0] = x[0] + c1 * k1[0] + c3 * k3[0] + c4 * k4[0] + c5 * k5[0]
            x[1] = x[1] + c1 * k1[1] + c3 * k3[1] + c4 * k4[1] + c5 * k5[1]
            # Save if at a save time index
            if add:
                while abs(t - times[times_index]) < tol:
                    res[times_index] = x[1]
                    t = times[times_index]
                    times_index += 1
                add = 0

        # Update time stepping
        dt = dt * min(max(0.84 * ( tolh / r[0] )**0.25, 0.1), 4.0)
        if dt > dt_max:
            dt = dt_max
        elif dt < dt_min:  # Equations are too stiff
            return res*0 - 100 # or something

        # ADD STOPPING CONDITION HERE...

    return res

cdef class F(cyfunc):
    cdef double a

    def __init__(self, double a):
        self.a = a

    cdef double* f(self, double y[2]):
        self.dy[0] = self.a*y[1] - y[0]
        self.dy[1] = y[0] - y[1]**2

        return self.dy

代码可以通过

运行

test.py

import numpy as np 
import matplotlib.pyplot as plt
import pyximport
pyximport.install(setup_args={'include_dirs': np.get_include()})
from cythoncode import rkf, F

x0 = np.array([1, 0], dtype=np.float64)
f = F(a=0.1)

t = np.linspace(0, 30, 100)
y = rkf(f, t, x0)

plt.plot(t, y)
plt.show()