使用scipy的solve_ivp在时间步之间运行代码

时间:2019-10-17 18:51:33

标签: python scipy

我正在将我的代码从使用scipy的odeint转换为scipy的solve_ivp。当使用odeint时,我将使用while循环,如下所示:

while solver.successful() : 
    solver.integrate(t_final, step=True)
    # do other operations

该方法使我可以在每个时间步长后存储取决于解决方案的值。

我现在要切换为使用solve_ivp,但不确定如何使用solve_ivp求解器来实现此功能。有人用solve_ivp完成了此功能吗?

谢谢!

1 个答案:

答案 0 :(得分:0)

我想我知道你想问什么。我有一个程序,使用solve_ivp分别在每个时间步骤之间进行积分,然后使用值来计算下一次迭代的值。 (即热传递系数,传递系数等),我使用了两个嵌套的for循环。内部的for循环计算或完成每个步骤所需的操作。然后将每个值保存在列表或数组中,然后内部循环应终止。外循环仅应用于提供时间值并可能重新加载必要的常量。

例如:

for i in range(start_value, end_value, time_step):
start_time = i
end_time = i + time_step
# load initial values and used most recent values
    for j in range(0, 1, 1):


    answer = solve_ivp(function,(start_time,end_time), [initial_values])
    # Save new values at the end of a list storing all calculated values

假设您拥有

之类的系统
  1. d(Y1)/ dt = a1 * Y2 + Y1

  2. d(Y2)/ dt = a2 * Y1 + Y2

,您想从t = 0,10开始求解。步长为0.1。其中a1和a2是在其他地方计算或确定的值。该代码将起作用。

from scipy.integrate import solve_ivp
import sympy as sp
import numpy as np
import math
import matplotlib.pyplot as plt



def a1(n):
       return 1E-10*math.exp(n)

def a2(n):
       return 2E-10*math.exp(n)

def rhs(t,y, *args):
       a1, a2 = args
       return [a1*y[1] + y[0],a2*y[0] + y[1]]

Y1 = [0.02]
Y2 = [0.01]
A1 = []
A2 = []
endtime = 10 
time_step = 0.1
times = np.linspace(0,endtime, int(endtime/time_step)+1)
tsymb = sp.symbols('t')
ysymb = sp.symbols('y')
for i in range(0,endtime,1):

       for j in range(0,int(1/time_step),1):
              tstart = i + j*time_step
              tend = i + j*time_step + time_step
              A1.append(a1(tstart/100))
              A2.append(a2(tstart/100))
              Y0 = [Y1[-1],Y2[-1]]
              args = [A1[-1],A2[-1]]
              answer = solve_ivp(lambda tsymb, ysymb : rhs(tsymb,ysymb, *args), (tstart,tend), Y0)
              Y1.append(answer.y[0][-1])
              Y2.append(answer.y[1][-1])

fig = plt.figure()
plt1 = plt.plot(times,Y1, label = "Y1")
plt2 = plt.plot(times,Y2, label = "Y2")
plt.xlabel('Time')
plt.ylabel('Y Values')
plt.legend()
plt.grid()
plt.show()