在Python中使用线性代数进行线性回归

时间:2019-04-26 15:55:03

标签: python linear-regression linear-algebra

我在维基百科(https://en.wikipedia.org/wiki/Coefficient_of_determination)上解释这些公式吗?  在Python中出错?下面是我尝试过的。

ssres

def ss_res(X, y, theta):

    y_diff=[]
    y_pred = X.dot(theta)

    for i in range(0, len(y)):
        y_diff.append((y[i]-y_pred[i])**2)

    return np.sum(y_diff)

输出看起来正确,但数字略有变化……就像一些小数点。

stderror

def std_error(X, y, theta):


    delta = (1/(len(y)-X.shape[1]+1))*(ss_res(X,y,theta))
    matrix1=matrix_power((X.T.dot(X)),-1)
    thing2=delta*matrix1
    thing3=scipy.linalg.sqrtm(thing2)

    res=np.diag(thing3)
    serr=np.reshape(res, (6, 1))
    return serr

std_error_array=std_error(X,y,theta)

1 个答案:

答案 0 :(得分:2)

您是否希望+1中的delta取决于您的X是否包含“常量”列(即所有值= 1)

否则,如果不是Python风格的话,看起来还可以。我很想将它们写为:

import numpy as np
from numpy.linalg import inv
from scipy.linalg import sqrtm

def solve_theta(X, Y):
    return np.linalg.solve(X.T @ X, X.T @ Y)

def ss_res(X, Y, theta):
    res = Y - (X @ theta)
    return np.sum(res ** 2)

def std_error(X, Y, theta):
    nr, rank = X.shape
    resid_df = nr - rank
    residvar = ss_res(X, Y, theta) / resid_df
    var_theta = residvar * inv(X.T @ X)
    return np.diag(sqrtm(var_theta))[:,None]

注意:这使用Python 3.5 style matrix multiply operator @而不是写出.dot()

这种算法的数值稳定性并不令人惊讶,您可能想看看使用SVD或QR分解。

中有一个平易近人的描述,您将如何使用SVD进行操作:
  

John Mandel(1982)“在回归分析中使用奇异值分解” 10.1080/00031305.1982.10482771

我们可以通过创建一些虚拟数据来进行测试:

np.random.seed(42)

N = 20
K = 3

true_theta = np.random.randn(K, 1) * 5
X = np.random.randn(N, K)
Y = np.random.randn(N, 1) + X @ true_theta

并在上面运行上面的代码:

theta = solve_theta(X, Y)
sse = std_error(X, Y, theta)

print(np.column_stack((theta, sse)))

给出:

[[ 2.23556391  0.35678574]
 [-0.40643163  0.24751913]
 [ 3.14687637  0.26461827]]

我们可以使用statsmodels进行测试:

import statsmodels.api as sm

sm.OLS(Y, X).fit().summary()

给出:

                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
x1             2.2356      0.358      6.243      0.000       1.480       2.991
x2            -0.4064      0.248     -1.641      0.119      -0.929       0.116
x3             3.1469      0.266     11.812      0.000       2.585       3.709

非常接近