scipy optimize.curve_fit不能适合其返回值取决于条件的函数

时间:2012-11-09 19:31:34

标签: python python-2.7 scipy

我想将一个定义如下的函数拟合为时间序列数据:

def func(t, a0, a1, a2, T, tau1, tau2):
    if t < T:
        return a0 + a1 * np.exp(-t/tau1) + a2 * np.exp(-t/tau2)
    else:
        return a0 + a1 * np.exp(-T/tau1) * (1 - t/tau1 + T/tau1) + a2 * np.exp(-T/tau2) * (1 - t/tau2 + T/tau2) 

这里,t表示进行测量的时间,其余参数是函数的参数。问题在于,当我将它提供给curve_fit时,Python会抱怨t&lt; T比较。我相信这是因为当在curve_fit中调用func时t变为数据点列表,而T是一个数字(不是列表):

popt, pcov = curve_fit(func, t1, d1)

其中t1是次数列表,d1是在相应时间测量的数据值列表。我尝试过多种方法来解决这个问题,但无济于事。有什么建议吗?非常感谢!

1 个答案:

答案 0 :(得分:5)

没错,t < T是一个布尔数组。 NumPy拒绝为布尔数组分配一个真值,因为有很多可能的选择 - 如果所有元素都是True,或者任何元素是True,那么它应该是True吗? / p>

但那没关系。在这种情况下,NumPy提供了一个很好的函数来替换if ... else ...块,即np.where

def func(t, a0, a1, a2, T, tau1, tau2):
    return np.where(
        t < T,
        a0 + a1 * np.exp(-t/tau1) + a2 * np.exp(-t/tau2),
        a0 + a1 * np.exp(-T/tau1) * (1 - t/tau1 + T/tau1) + a2 * np.exp(-T/tau2) * (1 - t/tau2 + T/tau2) )