Reg:拉伸指数函数拟合中的错误

时间:2015-09-23 15:03:31

标签: python python-2.7 curve-fitting

我在.csv文件中有数据,它包含2列x和y轴。从.csv文件中读取轴,然后使用拉伸指数函数拟合数据,但显示错误。

这里我给出的示例数据易于理解。

我的功能是f(x) = a. exp (-b.t) ^ c + d。 (拉伸指数拟合)。我想根据这个函数拟合这些数据,我想要a,b,c和d的最终值。

我的编码是:

# Reading data
x=data[1,2,3,4,5,6,7,8,9,10]
y=data[7.2489, 7.0123, 7.0006, 7.0003, 7, 7, 7, 7, 7, 7]
# Fitting Streched Exponential Decay Curve
smoothx = np.linspace(x[0], x[-1], (5*x[-1]))
guess_a, guess_b, guess_c, guess_d = 4000, -0.005, 4, 4000
guess = [guess_a, guess_b, guess_c, guess_d]
f_theory1 = lambda t, a, b, c, d: a * np.exp((b*t)^(c)) + d
p, cov = curve_fit(f_theory1, x, y, p0=np.array(guess))
f_fit1 = lambda t: p[0] * np.exp((p[1] * t)^((p[2]))) + p[3]
plt.show()

在这里,我只展示猜测和拟合程序的一部分。

请更正我的代码中的错误,以便更好地适应。

1 个答案:

答案 0 :(得分:1)

您可以使用lmfit来调整参数。然后情节如下:

enter image description here

,相应的参数如下:

a:   56.8404075 
b:  -5.43686170 
c:   49.9888343 
d:   7.00146666 

lmfit的优势在于您还可以使用minmax参数轻松约束参数范围(请参阅下面的代码)。

这是产生情节的代码;请注意,我稍微修改了你的模型,以避免从负数计算根:

from lmfit import minimize, Parameters, Parameter, report_fit
import numpy as np

x=np.array([1,2,3,4,5,6,7,8,9,10]  )
y=np.array([7.2489, 7.0123, 7.0006, 7.0003, 7, 7, 7, 7, 7, 7])  


def f_theory1(params, x, data):  
    a = params['a'].value
    b = params['b'].value
    c = params['c'].value
    d = params['d'].value

    model = a * np.exp(b*(x**c)) + d # now b can become negative; in your definition it could not

    return model - data #that's what you want to minimize

# create a set of Parameters
#'value' is the initial condition
#'min' and 'max' define your boundaries
params = Parameters()
params.add('a', value= 40, min=-10, max=10000) 
params.add('b', value= -0.005, min=-10, max=200)
params.add('c', value= .03, min=-10, max=400) 
params.add('d', value= 40.0, min=-10, max=400) 

# do fit, here with leastsq model
result = minimize(f_theory1, params, args=(x, y))

# calculate final result
final = y + result.residual

# write error report
report_fit(params)

#plot results
try:
    import matplotlib.pyplot as plt
    plt.plot(x, y, 'k+')
    plt.plot(x, final, 'r')
    plt.ylim([6.95, 7.3])
    plt.show()
except:
    pass