如何强制scipy.optimize中的leastsq函数优于每个数据点

时间:2016-07-13 09:08:33

标签: python numpy optimization scipy

我目前正在尝试使用scipy.optimize中的minimalsq方法计算一个函数来拟合某些数据点。

我正在寻找的功能类似于f(x) = A * cos(b*x + c),其中A,b,c是我想知道的参数。

我的代码到目前为止:

def residuals(p, y, x):
    A, b, c = p
    err = y - A * cos(b * x + c)
    return err


x = arange(-8, 9)
y = [0.060662282, 0.25381372, 0.357635814, 0.610186219, 0.689421037, 0.987387563,
 1.062490593, 1.09941534, 1.04789242, 1.05323342, 0.947636751, 0.929896615, 0.758757134, 0.572468578,
 0.422551093, 0.25694886, 0.029750763]

# The true parameters
A, b, c = 1.1, 0.2, 0.01
y_true = A * cos(b * x + c)

y_meas = array(y)

# initial guess
p0 = [1.0, 0.1, 0.05]
array(p0)

plsq = optimization.leastsq(residuals, p0, args=(y_meas, x))
print plsq[0]

这个想要的回报:

[1.07861728 0.19353103 0.00361659]

这项工作很好但我希望用这些参数计算的函数f(x) = A * cos(b * x + c)优于每个数据点。 换句话说,对于所有数据点(xdata; ydata),我希望f(xdata_i) > ydata_i

如果1.e6,我已经尝试在遗留函数中返回一个较大的值err > 0。但是,lesssq函数似乎并不欣赏它并为我返回一条不理解的错误消息。

1 个答案:

答案 0 :(得分:2)

一种方法是将您的问题从头开始作为受约束的优化问题,并通过例如scipy.optimize.minimize解决。

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import minimize

# data set
x = np.arange(-8, 9)
y = np.array([0.060662282, 0.25381372, 0.357635814, 0.610186219, 0.689421037, 0.987387563,
 1.062490593, 1.09941534, 1.04789242, 1.05323342, 0.947636751, 0.929896615, 0.758757134, 0.572468578,
 0.422551093, 0.25694886, 0.029750763])


# define the LS fit as the objective function
def obj(z):
    a, b, c = z

    return (np.abs(y - a * np.cos(b * x + c))**2).sum()

# define constraint that the fit should be larger than the samples
def constraint(z):
    a, b, c = z

    return a * np.cos(b * x + c) - y

# required input for 'minimize' function
cons = ({'type': 'ineq', 'fun': constraint},)

z0 = (0,0,0) # provide an initial estimate of the parameters 
sol = minimize(obj, z0, constraints = cons)
A_opt, b_opt, c_opt = sol.x
print (A_opt, b_opt, c_opt)

#plot fit
x_range = np.linspace(-8,9,100)
plt.plot(x,y,'o')
plt.plot(x_range, A_opt * np.cos(b_opt * x_range + c_opt) )
1.15736059083 0.18957445657 0.0198968389239

enter image description here