如何在 Python 中使用 curve_fit 进行曲线拟合

时间:2019-03-09 14:41:05

标签: python scikit-learn scipy curve-fitting

我正在用 Python 研究非线性曲线拟合(curve_fit)。
我写了下面的示例。
但是拟合出来的曲线画得并不好:

plt.plot(basketCont, fittedData)

我认为拟合得到的参数也不理想。
你能给些建议吗?谢谢。

import matplotlib
matplotlib.use('Qt4Agg')
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
import numpy as np
from scipy.optimize import curve_fit 

def func(x, a, b, c):
    """Evaluate the exponential model ``a - b * exp(c * x)``.

    ``x`` is the independent variable and comes first so the function can be
    handed directly to ``curve_fit``; ``a``, ``b`` and ``c`` are the
    parameters being fitted.
    """
    exponential_term = np.exp(c * x)
    return a - b * exponential_term

# Observed data: basket counts (x) and the scaling factor recorded for each (y).
baskets = np.array([475, 108, 2, 38, 320])
scaling_factor = np.array([95.5, 57.7, 1.4, 21.9, 88.8])

# Without an explicit p0, curve_fit starts from [1, 1, 1]. With c = 1 the term
# exp(c * x) is about 1e206 at x = 475, so the squared residuals overflow and
# the optimizer cannot make progress. Seed it with a saturating curve instead:
# plateau near the largest y, amplitude spanning the y range, and a negative
# rate on the scale of 1 / mean(x).
initial_guess = [
    scaling_factor.max(),
    scaling_factor.max() - scaling_factor.min(),
    -1.0 / baskets.mean(),
]

popt, pcov = curve_fit(func, baskets, scaling_factor, p0=initial_guess)

print(popt)
print(pcov)

# func is built from NumPy operations, so the whole grid is evaluated in one
# call instead of point by point.
basketCont = np.linspace(baskets.min(), baskets.max(), 50)
fittedData = func(basketCont, *popt)

fig1 = plt.figure(1)

# Raw samples as small markers, fitted model as a line on top.
plt.scatter(baskets, scaling_factor, s=5)
plt.plot(basketCont, fittedData)

plt.grid()

plt.show()

1 个答案:

答案 0 :(得分:2)

我用您发布的方程无法很好地拟合您的数据,但希尔(Hill)S 型方程拟合得非常好。下面是我使用的带绘图功能的 Python 拟合代码。

plot

import numpy, scipy, matplotlib
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
import warnings


# Measured samples: x values (baskets) and the matching y values.
baskets = numpy.array([475, 108, 2, 38, 320], dtype=float)
scaling_factor = numpy.array([95.5, 57.7, 1.4, 21.9, 88.8])

# Generic aliases so the rest of the script does not depend on the
# data-specific names and is easier to reuse.
xData, yData = baskets, scaling_factor


def func(x, a, b, c):
    """Hill sigmoidal equation (from zunzun.com).

    Computes ``a * x**b / (c**b + x**b)`` using ``numpy.power``.
    """
    x_to_b = numpy.power(x, b)
    c_to_b = numpy.power(c, b)
    return a * x_to_b / (c_to_b + x_to_b)


# Starting point for the optimizer: all ones, the same values scipy would
# use by default.
initialParameters = numpy.ones(3)

# Silence warnings so curve_fit() does not clutter the output.
warnings.filterwarnings("ignore")

# Fit the model to the test data.
fittedParameters, pcov = curve_fit(func, xData, yData, p0=initialParameters)

# Residuals of the fitted model at the sampled x values.
modelPredictions = func(xData, *fittedParameters)
absError = modelPredictions - yData

# Goodness-of-fit statistics derived from the residuals.
SE = numpy.square(absError)  # squared errors
MSE = SE.mean()  # mean squared error
RMSE = numpy.sqrt(MSE)  # root mean squared error
Rsquared = 1.0 - (absError.var() / yData.var())

print('Parameters:', fittedParameters)
print('RMSE:', RMSE)
print('R-squared:', Rsquared)

print()


##########################################################
# graphics output section
def ModelAndScatterPlot(graphWidth, graphHeight):
    """Show the raw data as markers with the fitted model drawn on top.

    graphWidth and graphHeight are divided by 100 to get the figure size,
    matching the dpi of 100 passed to the figure. The function reads the
    module-level xData, yData, func and fittedParameters.
    """
    size_inches = (graphWidth / 100.0, graphHeight / 100.0)
    figure = plt.figure(figsize=size_inches, dpi=100)
    ax = figure.add_subplot(111)

    # Raw data goes on first, drawn with the 'D' marker style.
    ax.plot(xData, yData, 'D')

    # Evaluate the fitted equation on an evenly spaced grid spanning the
    # observed x range, then draw it as a line.
    x_low, x_high = min(xData), max(xData)
    xModel = numpy.linspace(x_low, x_high)
    ax.plot(xModel, func(xModel, *fittedParameters))

    ax.set_xlabel('X Data')  # X axis data label
    ax.set_ylabel('Y Data')  # Y axis data label

    plt.show()
    plt.close('all')  # clean up after using pyplot

# Draw the data and fitted model using a width of 800 and a height of 600.
graphWidth, graphHeight = 800, 600
ModelAndScatterPlot(graphWidth, graphHeight)