最小二乘法不适用于一组y

时间:2017-10-13 18:13:37

标签: python numpy matplotlib least-squares

我正在尝试使用numpy运行最小二乘算法并遇到问题。有人可以告诉我在给定代码中我做错了什么吗?当我将y设置为y = np.power(X, 1) + np.random.rand(20)*3或x的其他一些合理函数时,一切正常。但是对于那些给定y值定义的特定y,我得到的情节是没有意义的。

这是一种数字问题吗?

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

X = np.arange(1,21)
y = np.array([-0.00454712, -0.00457764, -0.0045166 , -0.00442505, -0.00427246,
       -0.00411987, -0.00378418, -0.003479  , -0.00314331, -0.00259399,
       -0.00213623, -0.00146484, -0.00082397, -0.00030518,  0.00027466,
        0.00076294,  0.00146484,  0.00192261,  0.00247192,  0.00314331])

#y = np.power(X, 1) + np.random.rand(20)*3

w = np.linalg.lstsq(X.reshape(20, 1), y)[0]

plt.plot(X, y, 'red')
plt.plot(X, X*w[0], 'blue')
plt.show()

1 个答案:

答案 0 :(得分:1)

您确定拟合的数据与y变量数据之间存在线性关系吗?

使用您示例中的代码(y = np.power(X, 1) + np.random.rand(20)*3),您可以在y变量本身内置一个线性关系(带有一些噪声),这样您的绘图就可以使用线性方程进行相对较好的跟踪。

X = np.arange(1,21)

#y = np.power(X, 1) + np.random.rand(20)*3

w = np.linalg.lstsq(X.reshape(20, 1), y)[0]

plt.plot(X, y, 'red')
plt.plot(X, X*w[0], 'blue')
plt.show()

Plot

但是,当您替换y变量

之类的东西时
    y = np.array([-0.00454712, -0.00457764, -0.0045166 , -0.00442505, -0.00427246,
       -0.00411987, -0.00378418, -0.003479  , -0.00314331, -0.00259399,
       -0.00213623, -0.00146484, -0.00082397, -0.00030518,  0.00027466,
        0.00076294,  0.00146484,  0.00192261,  0.00247192,  0.00314331])    

你最终得到的东西不那么容易。

Plot2

查看documentation,如果您正在尝试适合这组值的内容,则需要构建一个常量组件,在这种情况下,lstsq默认情况下不会这样做。 文档陈述lstsq

  

将最小二乘解返回到线性矩阵方程。

     

求解方程a x = b

如果您真的想要将数据拟合为线性方程式,运行如下所示的代码将为您提供几乎与原始数据匹配的内容。但是,此过程背后的数据似乎具有多项式/指数驱动程序,这将使polyfit更好。

X = np.arange(1,21)
y = np.array([-0.00454712, -0.00457764, -0.0045166 , -0.00442505, -0.00427246,
       -0.00411987, -0.00378418, -0.003479  , -0.00314331, -0.00259399,
       -0.00213623, -0.00146484, -0.00082397, -0.00030518,  0.00027466,
        0.00076294,  0.00146484,  0.00192261,  0.00247192,  0.00314331])

#y = np.power(X, 1) + np.random.rand(20)*3
X2 = np.vstack([X, np.ones(len(X))]).T
w = np.linalg.lstsq(X2, y)[0]


plt.plot(X, y, 'red')
plt.plot(X, X.dot(w[0])+w[1], 'blue')
plt.show()

Plot3