我的简单线性回归有什么问题?

时间:2017-11-24 05:18:39

标签: python-3.x numpy linear-regression

我无法使用此代码找到theta。

我添加了绘图代码以帮助可视化问题。

请帮助我找到这段代码中的错误

由于

Normal Equation used in linear function

#if defined(USE_MULTIBYTE) || defined(USE_UNICODE) 
#else 
#if defined(UNICODE) || defined(_UNICODE) 
#define USE_UNICODE 
#endif 
#endif 

#if defined(USE_UNICODE) 
typedef std::wstring xString; 
#else typedef std::string xString; 
#endif 

Output plot

1 个答案:

答案 0 :(得分:1)

错误位于绘图线中,应为

plot(arr(N), arr(N)**2 * theta[2] + arr(N) * theta[1] + theta[0], y)

根据二次多项式模型。

也;我想你出于说明的原因用这种方式计算了最小二乘解,但在实践中,使用np.linalg.lstsq得到的线性最小二乘拟合如下,代码更短更有效:

N = 20
x = np.arange(1, N+1)
y = x**2 + 3
basis = np.vstack((x**0, x**1, x**2)).T  # basis for the space of quadratic polynomials 
theta = np.linalg.lstsq(basis, y)[0]   # least squares approximation to y in this basis
plt.plot(x, y, 'ro')                   # original points
plt.plot(x, basis.dot(theta))          # best fit
plt.show()

fit