我正在使用scipy.optimize.leastsq
进行曲线拟合。例如。对于高斯:
def fitGaussian(x, y, init=[1.0,0.0,4.0,0.1]):
fitfunc = lambda p, x: p[0]*np.exp(-(x-p[1])**2/(2*p[2]**2))+p[3] # Target function
errfunc = lambda p, x, y: fitfunc(p, x) - y # Distance to the target function
final, success = scipy.optimize.leastsq(errfunc, init[:], args=(x, y))
return fitfunc, final
现在,我想选择修复拟合中某些参数的值。我发现建议使用不同的包lmfit,我希望避免使用,或者非常通用,例如here。 因为我需要一个解决方案
我想出了以下内容,在每个参数上使用条件:
def fitGaussian2(x, y, init=[1.0,0.0,4.0,0.1], fix = [False, False, False, False]):
fitfunc = lambda p, x: (p[0] if not fix[0] else init[0])*np.exp(-(x-(p[1] if not fix[1] else init[1]))**2/(2*(p[2] if not fix[2] else init[2])**2))+(p[3] if not fix[3] else init[3])
errfunc = lambda p, x, y: fitfunc(p, x) - y # Distance to the target function
final, success = scipy.optimize.leastsq(errfunc, init[:], args=(x, y))
return fitfunc, final
虽然这种方法很好,但它既不实用也不美观。 所以我的问题是:是否有更好的方法在固定参数的scipy中执行曲线拟合?或者是否有包装,已经包含这样的参数修复?
答案 0 :(得分:1)
使用scipy
,我没有内置选项。你将永远像你已经做的那样进行解决。
如果您愿意使用包装包,我可以推荐自己的symfit
吗?这是scipy
的包装器,具有可读性和较少的样板代码作为其核心原则。在symfit
中,您的问题将解决为:
from symfit import parameters, variables, exp, Fit, Parameter
a, b, c, d = parameters('a, b, c, d')
x, y = variables('x, y')
model_dict = {y: a * exp(-(x - b)**2 / (2 * c**2)) + d}
fit = Fit(model_dict, x=xdata, y=ydata)
fit_result = fit.execute()
行a, b, c, d = parameters('a, b, c, d')
生成四个Parameter
个对象。修复例如参数c
为其初始值,请在调用fit.execute()
前执行以下操作:
c.value = 4.0
c.fixed = True
因此可能的最终结果可能是:
from symfit import parameters, variables, exp, Fit, Parameter
a, b, c, d = parameters('a, b, c, d')
x, y = variables('x, y')
c.value = 4.0
c.fixed = True
model_dict = {y: a * exp(-(x - b)**2 / (2 * c**2)) + d}
fit = Fit(model_dict, x=xdata, y=ydata)
fit_result = fit.execute()
如果您希望代码更具动态性,可以使用以下方式立即制作Parameter
个对象:
c = Parameter(4.0, fixed=True)
有关详细信息,请查看文档:{{3}}
答案 1 :(得分:-1)
上面使用symfit的例子肯定只是拟合方法的语法,但是,给出的例子是否真的约束了变量c?
如果查看fit_result.param
,您会收到以下信息:
OrderedDict([('a', 16.374368575343127),
('b', 0.49201249437123556),
('c', 0.5337962977235504),
('d', -9.55593614465743)])
参数c不是4.0。