使用scipy curve_fit错误适合

时间:2015-11-17 11:50:11

标签: python numpy scipy curve-fitting

我试图将一些数据拟合到具有指数切断的幂律函数。我用numpy生成一些数据,我试图用scipy.optimization来适应这些数据。 这是我的代码:

import numpy as np
from scipy.optimize import curve_fit

def func(x, A, B, alpha):
    return A * x**alpha * np.exp(B * x)

xdata = np.linspace(1, 10**8, 1000)
ydata = func(xdata, 0.004, -2*10**-8, -0.75)
popt, pcov = curve_fit(func, xdata, ydata)
print popt

我得到的结果是:[1,1,1]与数据不对应。 我做错了什么?

2 个答案:

答案 0 :(得分:4)

虽然xnx给出了为什么curve_fit在这里失败的答案,我想我会提出一种不同的方法来解决你的功能形式的问题,这种问题不依赖于梯度下降(因此合理的初步猜测)

请注意,如果您获取适合的函数的日志,则会获得表单

\log f = \log A + \alpha \log x + B x

每个未知参数(log A,alpha,B)中都是线性的

因此我们可以使用线性代数机制通过以矩阵的形式写出等式来解决这个问题

log y = M p

其中log y是ydata点日志的列向量,p是未知参数的列向量,M是矩阵[[1], [log x], [x]]

或明确

enter image description here

然后可以使用np.linalg.lstsq

有效地找到最佳拟合参数向量

然后,代码中的示例问题可以写成

import numpy as np

def func(x, A, B, alpha):
    return A * x**alpha * np.exp(B * x)

A_true = 0.004
alpha_true = -0.75
B_true = -2*10**-8

xdata = np.linspace(1, 10**8, 1000)
ydata = func(xdata, A_true, B_true, alpha_true)

M = np.vstack([np.ones(len(xdata)), np.log(xdata), xdata]).T

logA, alpha, B = np.linalg.lstsq(M, np.log(ydata))[0]

print "A =", np.exp(logA)
print "alpha =", alpha
print "B =", B

很好地恢复了初始参数:

A = 0.00400000003736
alpha = -0.750000000928
B = -1.9999999934e-08

另请注意,此方法比使用curve_fit解决手头的问题快了约20倍

In [8]: %timeit np.linalg.lstsq(np.vstack([np.ones(len(xdata)), np.log(xdata), xdata]).T, np.log(ydata))
10000 loops, best of 3: 169 µs per loop


In [2]: %timeit curve_fit(func, xdata, ydata, [0.01, -5e-7, -0.4])
100 loops, best of 3: 4.44 ms per loop

答案 1 :(得分:2)

显然你的初始猜测(默认为1,因为你没有给出一个 - 见the docs)与实际参数相差太远以允许算法收敛。主要问题可能在于[1,1,1],如果为正,则会将指数函数发送到您提供的B的非常大的值。

尝试提供更接近实际参数的东西,它可以工作:

xdata

输出:

p0 = 0.01, -5e-7, -0.4    # Initial guess for the parameters
popt, pcov = curve_fit(func, xdata, ydata, p0)
print popt