我正在使用Scikit-learn库进行线性回归。一切都很简单明了。 用6行代码,我可以完成工作。但是,我想确切地了解背后的情况。
由于我是ML的初学者,也许我的问题是错的,但是我想知道Scikit-learn使用什么算法来最小化其linear_regression方法中的均方误差。
答案 0 :(得分:2)
从实现的角度来看,这只是包装为预测对象的普通最小二乘(scipy.linalg.lstsq)。
您也可以在源代码here中调用linalg.lstsq。
关于发生了什么的额外说明:
如果线性公式为a * x + b
,则可以使用属性a
和b
访问系数(coef_
)和偏差(intercept_
)。训练过的模型。
一个玩具示例,该示例从3个点生成对角线以显示coef_
和intercept_
属性:
from sklearn.linear_model import LinearRegression
X = np.array([[1], [2], [3]])
y = np.array([1, 2, 3])
lg = LinearRegression()
lg.fit(X, y)
lg.coef_ # 1
lg.intercept_ # ~ 0
答案 1 :(得分:1)
Scikit-Learn的LinearRegression使用封闭形式的解决方案,即OLS解决方案。它专门使用scipy的Ordinary Least Squares求解器。