在Python

时间:2015-08-28 12:27:30

标签: python-2.7 curve-fitting

我试图关注并重新使用一段名为@ThePredator的人建议的代码(用我自己的数据)(我无法评论该线程,因为我目前没有50的必需声誉)。完整代码如下:

import numpy as np # This is the Numpy module
from scipy.optimize import curve_fit # The module that contains the curve_fit routine
import matplotlib.pyplot as plt # This is the matplotlib module which we use for plotting the result

""" Below is the function that returns the final y according to the conditions """

def fitfunc(x,a1,a2):
    y1 = (x**(a1) )[x<xc]
    y2 = (x**(a1-a2) )[x>xc]
    y3 = (0)[x==xc]
    y = np.concatenate((y1,y2,y3))
    return y

x = array([0.001, 0.524, 0.625, 0.670, 0.790, 0.910, 1.240, 1.640, 2.180, 35460])
y = array([7.435e-13, 3.374e-14, 1.953e-14, 3.848e-14, 4.510e-14, 5.702e-14, 5.176e-14, 6.0e-14,3.049e-14,1.12e-17])

""" In the above code, we have imported 3 modules, namely Numpy, Scipy and  matplotlib """

popt,pcov = curve_fit(fitfunc,x,y,p0=(10.0,1.0)) #here we provide random initial parameters a1,a2

a1 = popt[0] 
a2 = popt[1]
residuals = y - fitfunc(x,a1,a2)
chi-sq = sum( (residuals**2)/fitfunc(x,a1,a2) ) # This is the chi-square for your fitted curve

""" Now if you need to plot, perform the code below """
curvey = fitfunc(x,a1,a2) # This is your y axis fit-line

plt.plot(x, curvey, 'red', label='The best-fit line')
plt.scatter(x,y, c='b',label='The data points')
plt.legend(loc='best')
plt.show()

运行此代码时出现问题,我得到的错误如下:

y3 =(0)[x == xc]

TypeError:'int'对象没有属性' getitem '

还有:

xc未定义

我没有看到代码中缺少任何内容(不应该定义xc?)。

作者(@ThePredator)或其他有此知识的人请帮我辨别一下我没见过的。

  • 新版代码:

    import numpy as np # This is the Numpy module
    from scipy.optimize import curve_fit 
    import matplotlib.pyplot as plt 
    
    def fitfunc(x, a1, a2, xc):
        if x.all() < xc:
          y = x**a1
        elif x.all() > xc:
          y = x**(a1 - a2) * x**a2
        else:
          y = 0
        return y
    
    xc = 2
    x = np.array([0.001, 0.524, 0.625, 0.670, 0.790, 0.910, 1.240, 1.640, 2.180, 35460])
    y = np.array([7.435e-13, 3.374e-14, 1.953e-14, 3.848e-14, 4.510e-14, 5.702e-14, 5.176e-14, 6.0e-14,3.049e-14,1.12e-17])
    
    popt,pcov = curve_fit(fitfunc,x,y,p0=(1.0,1.0)) 
    
    a1 = popt[0] 
    a2 = popt[1]
    residuals = y - fitfunc(x, a1, a2, xc)
    chisq = sum((residuals**2)/fitfunc(x, a1, a2, xc)) 
    curvey = [fitfunc(val, a1, a2, xc) for val in x] #  y-axis fit-line
    
    plt.plot(x, curvey, 'red', label='The best-fit line')
    plt.scatter(x,y, c='b',label='The data points')
    plt.legend(loc='best')
    plt.show()
    

2 个答案:

答案 0 :(得分:0)

您的代码中存在多个错误/拼写错误。

1)你不能在Python的变量名中使用-(例如,卡方应为chi_square

2)您应from numpy import array或将array替换为np.array。目前,未定义名称array

3)xc未定义,您应在调用fitfunc()之前进行设置。

4)y3 = (0)[x==xc]无效,应该(我认为)y3 = np.zeros(len(x))[x==xc]y3 = np.zeros(np.sum(x==xc))

您使用fit_function()是错误的,因为它会改变图像的顺序。你想要的是:

def fit_function(x, a1, a2, xc):
    if x < xc:
        y = x**a1
    elif x > xc:
        y = x**(a1 - a2) * x**a2
    else:
        y = 0
    return y
xc = 2 #or any value you want
curvey = [fit_function(val, a1, a2, xc) for val in x]

答案 1 :(得分:0)

您可以执行以下操作来定义您的功能,它将解决。 x是一个数组(或列表),它应该将y作为数组(或列表)返回。然后你可以在curvefit中使用它。

def fit_function(x, a1, a2, xc):
    y = []
    for xx in x:
        if xx<xc:
            y.append(x**a1)
        elif xx>xc:
            y.append(x**(a1 - a2) * x**a2)
        else:
            y.append(0.0)
    return y