odeint:根据规则“安全”,无法将数组数据从dtype('complex128')转换为dtype('float64')

时间:2020-04-13 21:41:17

标签: python scipy differential-equations odeint ode45

以下代码给出错误:无法根据规则“安全”将数组数据从dtype('complex128')转换为dtype('float64')

import numpy as np
from numpy.fft import fft
from scipy.integrate import odeint

t = np.linspace(0,9,10)

def func(y, t):
    k = 0
    dydt = fft(y)
    return dydt

y0 = 0
y = odeint(func, y0, t)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-3-4885da912033> in <module>
     10 
     11 y0 = 0
---> 12 y = odeint(func, y0, t)

~\AppData\Local\Continuum\anaconda3\envs\udacityDL\lib\site-packages\scipy\integrate\odepack.py in odeint(func, y0, t, args, Dfun, col_deriv, full_output, ml, mu, rtol, atol, tcrit, h0, hmax, hmin, ixpr, mxstep, mxhnil, mxordn, mxords, printmessg, tfirst)
    243                              full_output, rtol, atol, tcrit, h0, hmax, hmin,
    244                              ixpr, mxstep, mxhnil, mxordn, mxords,
--> 245                              int(bool(tfirst)))
    246     if output[-1] < 0:
    247         warning_msg = _msgs[output[-1]] + " Run with full_output = 1 to get quantitative information."

TypeError: Cannot cast array data from dtype('complex128') to dtype('float64') according to the rule 'safe'

但是,如果我从func返回实值(而不是复数),例如:

def func(y, t):
    k = 0
    dydt = fft(y)
    return np.abs(dydt)

然后odeint正常工作。

任何人都可以帮助我确定/解决此问题的根源吗?

谢谢!

2 个答案:

答案 0 :(得分:0)

您正在更改数据类型并返回复杂的值,而ODE求解器将在此期望真实值。

您还可以尝试使您的输入变得复杂,从而至少在这一点上不会产生矛盾

y0 = 0+0j

您期望解决方案到底是什么?通过任何合理的解释,您将获得零函数。

答案 1 :(得分:0)

谢谢卢兹!基本上,我正在尝试求解以下(薛定))微分方程:

u_t = i/2 u_xx + |u|^2 u

通过两边的傅立叶变换,我们得到:

fft(u_t) = -i/2 k^2 fft(u) + i fft(|u|^2 u)

其中,u_t是u w.r.t.的一阶导数。时间,k是波数。以下是我的更正/修改后的代码。

import numpy as np
from numpy.fft import fft
from scipy.integrate import odeint

def f(t, ut, k):
    u = ifft(ut)
    return -(1j/2) * k**2 * ut + 1j * fft( np.abs(u)**2 * u )



# Simulation
L = 30
n = 512
x2 = np.linspace(-L/2, L/2, n+1)
x = x2[0:n]

t = np.linspace(0, 2*np.pi, 41)
dt = t[1]-t[0]

k = 2*np.pi/L * np.concatenate( [np.arange(0, n/2), np.arange(-n/2,0,1)] )
print(x.shape, t.shape, k.shape)

# Initial solutions
u = 1 / np.cosh(x)
ut = fft(u)

# Solution for first value of k
y0, t0 = ut[0], t[0]
utsol = ode(f).set_integrator('zvode', method='bdf')
utsol.set_initial_value(y0, t0).set_f_params(k[0])
for i in t:
    usol.append( utsol.integrate(i+dt) ) # no idea if using i+dt is correct!
usol = np.array(usol)

这应该为usol提供len(t)x 1的形状

对于k中的每个值,我的最终解应该是len(t)x n的形状,因为k有n个元素。

我还使用Matlab中的ode45解决了这个问题。从Python获得的解决方案与我在Matlab中发现的解决方案有很大的不同。我知道Matlab解决方案是正确的。