scipy.integrate.solve_ivp矢量化

时间:2019-03-04 20:29:14

标签: python scipy ode

尝试将向量化选项用于solve_ivp,奇怪的是,它抛出一个错误,y0必须为一维。 MWE:

from scipy.integrate import solve_ivp
import numpy as np
import math

def f(t, y):
    theta = math.pi/4
    ham = np.array([[1,0],[1,np.exp(-1j*theta*t)]])
    return-1j * np.dot(ham,y)


def main():

    y0 = np.eye(2,dtype= np.complex128)
    t0 = 0
    tmax = 10**(-6)
    sol=solve_ivp( lambda t,y :f(t,y),(t0,tmax),y0,method='RK45',vectorized=True)
    print(sol.y)

if __name__ == '__main__':
    main()
  

主叫签名是fun(t,y)。这里t是一个标量,并且ndarray y有两个选择:它可以具有形状(n,);也可以具有形状(n,)。然后乐趣必须返回形状为(n,)的array_like。或者,它可以具有形状(n,k);那么fun必须返回一个形状为(n,k)的array_like,即每一列对应于y中的单个列。这两个选项之间的选择由矢量化参数确定(请参见下文)。向量化的实现可以通过有限的差分(刚性求解器需要)更快地逼近雅可比行列。

错误:

  

ValueError:y0必须是一维的。

Python 3.6.8

scipy。版本 '1.2.1'

1 个答案:

答案 0 :(得分:0)

这里vectorize的含义有些混乱。这并不意味着y0可以是2d,而是y传递给函数时可以是2d。换句话说,如果求解器愿意,func可以一次在多个点求值。有多少点取决于求解器,而不是您。

更改f,以在每次调用时显示y的形状:

def f(t, y):
    print(y.shape)
    theta = math.pi/4
    ham = np.array([[1,0],[1,np.exp(-1j*theta*t)]])
    return-1j * np.dot(ham,y)

一个示例通话:

In [47]: integrate.solve_ivp(f,(t0,tmax),[1j,0],method='RK45',vectorized=False) 
(2,)
(2,)
(2,)
(2,)
(2,)
(2,)
(2,)
(2,)
Out[47]: 
  message: 'The solver successfully reached the end of the integration interval.'
     nfev: 8
     njev: 0
      nlu: 0
      sol: None
   status: 0
  success: True
        t: array([0.e+00, 1.e-06])
 t_events: None
        y: array([[0.e+00+1.e+00j, 1.e-06+1.e+00j],
       [0.e+00+0.e+00j, 1.e-06-1.e-12j]])

相同的通话,但带有vectorize=True

In [48]: integrate.solve_ivp(f,(t0,tmax),[1j,0],method='RK45',vectorized=True)  
(2, 1)
(2, 1)
(2, 1)
(2, 1)
(2, 1)
(2, 1)
(2, 1)
(2, 1)
Out[48]: 
  message: 'The solver successfully reached the end of the integration interval.'
     nfev: 8
     njev: 0
      nlu: 0
      sol: None
   status: 0
  success: True
        t: array([0.e+00, 1.e-06])
 t_events: None
        y: array([[0.e+00+1.e+00j, 1.e-06+1.e+00j],
       [0.e+00+0.e+00j, 1.e-06-1.e-12j]])

如果设置为False,则传递给y的{​​{1}}是(2,),1d;如果为True,则为(2,1)。我想如果求解器方法如此需要,它可能是(2,2)甚至是(2,3)。这样可以减少对f的调用,从而加快执行速度。在这种情况下,没关系。

f具有类似的quadrature布尔参数:

Numerical Quadrature of scalar valued function with vector input using scipy

相关的错误/问题讨论:

https://github.com/scipy/scipy/issues/8922