我正在用 Python 研究非线性曲线拟合(使用 scipy 的 curve_fit)。
我做了下面的例子。
但是拟合出来的曲线画得不好,也就是下面这一行绘制的曲线:
plt.plot(basketCont, fittedData)
我认为拟合得到的参数也不理想。
你能给些建议吗?谢谢。
import matplotlib
matplotlib.use('Qt4Agg')
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
import numpy as np
from scipy.optimize import curve_fit
def func(x, a, b, c):
    """Evaluate the exponential model ``a - b * exp(c * x)``.

    Args:
        x: scalar or NumPy array of points at which to evaluate the model.
        a: constant offset term.
        b: multiplier of the exponential term.
        c: rate inside the exponent.

    Returns:
        The model value(s), with the same shape as ``x``.
    """
    return a - b * np.exp(c * x)
baskets = np.array([475, 108, 2, 38, 320])
scaling_factor = np.array([95.5, 57.7, 1.4, 21.9, 88.8])

# Without p0, curve_fit starts from a = b = c = 1.  With c = 1 the term
# exp(c * x) is astronomically large for x up to 475, so the optimizer starts
# far away from the saturating (c < 0) curve the data follows.  Derive a
# starting point from the data instead:
#   a ~ largest observed value (the level the curve flattens out at),
#   b ~ observed range (so the model starts near the smallest value at x = 0),
#   c ~ negative rate on the scale of the x data.
initial_guess = [
    scaling_factor.max(),
    scaling_factor.max() - scaling_factor.min(),
    -1.0 / baskets.mean(),
]
popt, pcov = curve_fit(func, baskets, scaling_factor, p0=initial_guess)
print(popt)
print(pcov)

# Evaluate the fitted model on an evenly spaced grid across the data range;
# func is built from NumPy operations, so it takes the whole grid at once.
basketCont = np.linspace(baskets.min(), baskets.max(), 50)
fittedData = func(basketCont, *popt)

fig1 = plt.figure(1)
plt.scatter(baskets, scaling_factor, s=5)
plt.plot(basketCont, fittedData)
plt.grid()
plt.show()
答案 0 :(得分:2)
我本人用您给出的方程无法很好地拟合这些数据,但希尔(Hill)S 型方程拟合得非常好。下面是我使用的带绘图功能的拟合程序的 Python 代码。
import numpy, scipy, matplotlib
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
import warnings
# Each pair below is one observation: (basket count, scaling factor).
# zip() regroups the pairs into one column per variable.
baskets, scaling_factor = map(
    numpy.array,
    zip(
        (475.0, 95.5),
        (108.0, 57.7),
        (2.0, 1.4),
        (38.0, 21.9),
        (320.0, 88.8),
    ),
)

# generic aliases so the fitting and plotting code below can be reused as-is
xData, yData = baskets, scaling_factor
def func(x, a, b, c):  # Hill sigmoidal equation from zunzun.com
    """Evaluate the Hill sigmoidal model ``a * x**b / (c**b + x**b)``.

    Args:
        x: scalar or NumPy array of points at which to evaluate the model.
        a: multiplier of the ratio.
        b: exponent applied to both ``x`` and ``c``.
        c: value of ``x`` at which the ratio equals one half.

    Returns:
        The model value(s), with the same shape as ``x``.
    """
    # x**b appears in both numerator and denominator; compute it once
    x_pow_b = numpy.power(x, b)
    return a * x_pow_b / (numpy.power(c, b) + x_pow_b)
# starting point for the optimizer; these are the same as the scipy defaults
initialParameters = numpy.array([1.0, 1.0, 1.0])

# curve fit the test data; suppress warnings only for the duration of the fit
# so that the process-wide warning filters are left untouched afterwards
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    fittedParameters, pcov = curve_fit(func, xData, yData, initialParameters)

modelPredictions = func(xData, *fittedParameters)

absError = modelPredictions - yData

SE = numpy.square(absError)  # squared errors
MSE = numpy.mean(SE)  # mean squared errors
RMSE = numpy.sqrt(MSE)  # Root Mean Squared Error, RMSE
# R-squared = 1 - SS_res / SS_tot.  Use the mean squared residual rather than
# the variance of the residuals: the variance subtracts the mean residual and
# would therefore hide any systematic offset between model and data.
Rsquared = 1.0 - (MSE / numpy.var(yData))

print('Parameters:', fittedParameters)
print('RMSE:', RMSE)
print('R-squared:', Rsquared)
print()
##########################################################
# graphics output section
def ModelAndScatterPlot(graphWidth, graphHeight):
    """Show the raw data and the fitted model curve in a single figure.

    Reads the module-level ``xData``, ``yData``, ``func`` and
    ``fittedParameters``.

    Args:
        graphWidth: figure width in pixels (the figure is created at 100 dpi).
        graphHeight: figure height in pixels (the figure is created at 100 dpi).
    """
    f = plt.figure(figsize=(graphWidth / 100.0, graphHeight / 100.0), dpi=100)
    axes = f.add_subplot(111)

    # first the raw data as a scatter plot
    axes.plot(xData, yData, 'D')

    # create data for the fitted equation plot: an evenly spaced grid
    # spanning the observed x range
    xModel = numpy.linspace(numpy.min(xData), numpy.max(xData))
    yModel = func(xModel, *fittedParameters)

    # now the model as a line plot
    axes.plot(xModel, yModel)

    axes.set_xlabel('X Data')  # X axis data label
    axes.set_ylabel('Y Data')  # Y axis data label

    plt.show()
    # release only the figure created here, leaving any other open figures alone
    plt.close(f)
# requested figure size, handed straight to the plotting routine
graphWidth, graphHeight = 800, 600
ModelAndScatterPlot(graphWidth, graphHeight)