对于scikit-learn中LinearRegression中的fit方法,为什么需要包含y坐标的第二个参数?

时间:2017-12-06 16:58:48

标签: python scikit-learn

sklearn.linear_model中的LinearRegression对象可用于将数据点拟合到一条线。从下面的代码可以看出,fit方法有两个参数,即点列表和另一个y坐标列表。

from sklearn import linear_model
reg = linear_model.LinearRegression()
reg.fit ([[0, 0], [1, 1], [2, 2]], [0, 1, 2])

我的问题是:为什么第二个参数甚至需要?这不是多余的信息吗?

2 个答案:

答案 0 :(得分:2)

  

使数据点符合一行

它非常适合您的数据点。

  

fit方法有两个参数,即点列表和另一个y坐标列表。

X是您的数据样本,其中每一行都是数据点(一个样本,一个N维特征向量)。 y是数据点标签,每个数据点一个。 fit方法找到矩阵W(要素权重)和向量b(偏差),以便最小化预测yhat = Wx + b与真实y之间的距离。

E.g。如果您获得了坐标为[x,y]的二维数据点,并且您希望基于y预测x,则会将x作为第一个参数传递给{{1} } s作为y的第二个参数。

答案 1 :(得分:2)

线性模型不仅限于1个预测变量和1个响应变量。换句话说,您可以使用X和Y作为预测响应变量Z的两个预测变量,其中Z可能线性地依赖于X和Y.在您的情况下,您只是尝试从X预测Y,因此更改代码以下内容:

from sklearn import linear_model
reg = linear_model.LinearRegression()
reg.fit ([[0], [1], [2]], [0, 1, 2])