scipy.optimize.curve_fit设置“固定”参数

时间:2015-07-29 15:51:49

标签: python numpy

我正在使用scipy的curve_fit来使用高斯函数逼近数据中的峰值。这适用于强峰,但对于较弱的峰值则更难。但是,我认为修复一个参数(比如高斯的宽度)会对此有所帮助。我知道我可以设置初始的“估计”,但有没有办法可以轻松定义单个参数而不改变我适合的函数?

2 个答案:

答案 0 :(得分:0)

如果您要“固定”拟合函数的参数,则可以定义一个新的拟合函数,该函数利用原始的拟合函数,但将一个参数设置为固定值:

custom_gaussian = lambda x, mu: gaussian(x, mu, 0.05)

这是将Gaussian function中的sigma固定为0.05(而不是最优值0.1)的完整示例。当然,这在这里并没有真正意义,因为该算法在寻找最佳值时没有问题。但是,您可以看到在固定mu的情况下仍如何找到sigma

import matplotlib.pyplot as plt
import numpy as np
import scipy.optimize 

def gaussian(x, mu, sigma):
    return 1 / sigma / np.sqrt(2 * np.pi) * np.exp(-(x - mu)**2 / 2 / sigma**2)

# Create sample data
x = np.linspace(0, 2, 200)
y = gaussian(x, 1, 0.1) + np.random.rand(*x.shape) - 0.5
plt.plot(x, y, label="sample data")

# Fit with original fit function
popt, _ = scipy.optimize.curve_fit(gaussian, x, y)
plt.plot(x, gaussian(x, *popt), label="gaussian")

# Fit with custom fit function with fixed `sigma`
custom_gaussian = lambda x, mu: gaussian(x, mu, 0.05)
popt, _ = scipy.optimize.curve_fit(custom_gaussian, x, y)
plt.plot(x, custom_gaussian(x, *popt), label="custom_gaussian")

plt.legend()
plt.show()

figure

答案 1 :(得分:-1)

希望这很有帮助。不得不使用hax。 Curve_fit对所需的内容非常严格。

import numpy as np
from numpy import random
import scipy as sp
from scipy.optimize import curve_fit
import matplotlib.pyplot as pl

def exp1(t,a1,tau1):
    #A1*exp(-t/t1)
    val=0.
    val=(a1*np.exp(-t/tau1))*np.heaviside(t,0)
    return val

def wrapper(t,*args):

    global hold
    global p0
    wrapperName='exp1(t,'
    for i in range(0, len(hold)):
        if hold[i]:
            wrapperName+=str(p0[i])
        else:
            if i%2==0:
                wrapperName+='args['+str(i)+']'
            else:
                wrapperName+='args'+str(i)+']'
        if i<len(hold):
            wrapperName+=','
    wrapperName+=')'

    return eval(wrapperName)

p0=np.array([1.5,500.])
hold=np.array([0,1])
p1=np.delete(p0,1)

timepoints = np.arange(0.,2000.,20.)
y=exp1(timepoints,1,1000)+np.random.normal(0, .1, size=len(timepoints))

popt, pcov = curve_fit(exp1, timepoints, y, p0=p0)
print 'unheld parameters:', popt, pcov

popt, pcov = curve_fit(wrapper, timepoints, y, p0=p1)
for i in range(0, len(hold)):
    if hold[i]:
        popt=np.insert(popt,i,p0[i])
yfit=exp1(timepoints,popt[0],popt[1])
pl.plot(timepoints,y,timepoints,yfit)
pl.show()
print 'hold parameters:', popt, pcov