约束截距的线性回归

时间:2019-11-02 07:30:03

标签: python regression linear-regression

我确实想用截距值进行约束线性回归,例如: lowerbound <=拦截<=上限。

我知道我可以使用一些python库来约束系数,但是找不到可以约束截距的库。

我想要的是在截距在我定义的范围内的约束下,以最小的可能误差获得最适合我的数据点的最佳解决方案。

如何在python中做到这一点?

1 个答案:

答案 0 :(得分:1)

这里是使用带有参数范围的curve_fit的示例。在此示例中,参数“ a”是无界的,参数“ b”是有界的,拟合的值在那些边界内,参数“ c”是有界的,拟合的值在边界。

.innerHTML

更新:根据评论,这是一个多元拟合示例:

import numpy
import matplotlib
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

xData = numpy.array([5.0, 6.1, 7.2, 8.3, 9.4])
yData = numpy.array([ 10.0,  18.4,  20.8,  23.2,  35.0])


def standardFunc(data, a, b, c):
    return a * data + b * data**2 + c


# some initial parameter values - must be within bounds
initialParameters = numpy.array([1.0, 1.0, 1.0])

# bounds on parameters - initial parameters must be within these
lowerBounds = (-numpy.Inf, -100.0, -5.0)
upperBounds = (numpy.Inf, 100.0, 5.0)
parameterBounds = [lowerBounds, upperBounds]

fittedParameters, pcov = curve_fit(standardFunc, xData, yData, initialParameters, bounds = parameterBounds)

# values for display of fitted function
a, b, c = fittedParameters

# for plotting the fitting results
xPlotData = numpy.linspace(min(xData), max(xData), 50)
y_plot = standardFunc(xPlotData, a, b, c)

plt.plot(xData, yData, 'D') # plot the raw data as a scatterplot
plt.plot(xPlotData, y_plot) # plot the equation using the fitted parameters
plt.show()

print('fitted parameters:', fittedParameters)