我正在使用多个变量/特征进行线性回归。我试图通过使用正规方程方法(使用矩阵逆),Numpy最小二乘numpy.linalg.lstsq工具和np.linalg.solve工具来获得thetaas(系数)。在我的数据中,我有 n = 143 功能和 m = 13000 训练示例。
对于正则化的正规方程方法,我使用以下公式:
正则化用于解决矩阵不可逆性的潜在问题(XtX
矩阵可能变为奇异/不可逆)
数据准备代码:
import pandas as pd
import numpy as np
path = 'DB2.csv'
data = pd.read_csv(path, header=None, delimiter=";")
data.insert(0, 'Ones', 1)
cols = data.shape[1]
X = data.iloc[:,0:cols-1]
y = data.iloc[:,cols-1:cols]
IdentitySize = X.shape[1]
IdentityMatrix= np.zeros((IdentitySize, IdentitySize))
np.fill_diagonal(IdentityMatrix, 1)
对于最小二乘法方法,我使用Numpy的 numpy.linalg.lstsq 。这是Python代码:
lamb = 1
th = np.linalg.lstsq(X.T.dot(X) + lamb * IdentityMatrix, X.T.dot(y))[0]
我还使用numpy的 np.linalg.solve 工具:
lamb = 1
XtX_lamb = X.T.dot(X) + lamb * IdentityMatrix
XtY = X.T.dot(y)
x = np.linalg.solve(XtX_lamb, XtY);
对于正规方程式,我使用:
lamb = 1
xTx = X.T.dot(X) + lamb * IdentityMatrix
XtX = np.linalg.inv(xTx)
XtX_xT = XtX.dot(X.T)
theta = XtX_xT.dot(y)
在所有方法中,我都使用了正则化。以下是结果(theta系数),以了解这三种方法之间的差异:
Normal equation: np.linalg.lstsq np.linalg.solve
[-27551.99918303] [-27551.95276154] [-27551.9991855]
[-940.27518383] [-940.27520138] [-940.27518383]
[-9332.54653964] [-9332.55448263] [-9332.54654461]
[-3149.02902071] [-3149.03496582] [-3149.02900965]
[-1863.25125909] [-1863.2631435] [-1863.25126344]
[-2779.91105618] [-2779.92175308] [-2779.91105347]
[-1226.60014026] [-1226.61033117] [-1226.60014192]
[-920.73334259] [-920.74331432] [-920.73334194]
[-6278.44238081] [-6278.45496955] [-6278.44237847]
[-2001.48544938] [-2001.49566981] [-2001.48545349]
[-715.79204971] [-715.79664124] [-715.79204921]
[ 4039.38847472] [ 4039.38302499] [ 4039.38847515]
[-2362.54853195] [-2362.55280478] [-2362.54853139]
[-12730.8039209] [-12730.80866036] [-12730.80392076]
[-24872.79868125] [-24872.80203459] [-24872.79867954]
[-3402.50791863] [-3402.5140501] [-3402.50793382]
[ 253.47894001] [ 253.47177732] [ 253.47892472]
[-5998.2045186] [-5998.20513905] [-5998.2045184]
[ 198.40560401] [ 198.4049081] [ 198.4056042]
[ 4368.97581411] [ 4368.97175688] [ 4368.97581426]
[-2885.68026222] [-2885.68154407] [-2885.68026205]
[ 1218.76602731] [ 1218.76562838] [ 1218.7660275]
[-1423.73583813] [-1423.7369068] [-1423.73583793]
[ 173.19125007] [ 173.19086525] [ 173.19125024]
[-3560.81709538] [-3560.81650156] [-3560.8170952]
[-142.68135768] [-142.68162508] [-142.6813575]
[-2010.89489111] [-2010.89601322] [-2010.89489092]
[-4463.64701238] [-4463.64742877] [-4463.64701219]
[ 17074.62997704] [ 17074.62974609] [ 17074.62997723]
[ 7917.75662561] [ 7917.75682048] [ 7917.75662578]
[-4234.16758492] [-4234.16847544] [-4234.16758474]
[-5500.10566329] [-5500.106558] [-5500.10566309]
[-5997.79002683] [-5997.7904842] [-5997.79002634]
[ 1376.42726683] [ 1376.42629704] [ 1376.42726705]
[ 6056.87496151] [ 6056.87452659] [ 6056.87496175]
[ 8149.0123667] [ 8149.01209157] [ 8149.01236827]
[-7273.3450484] [-7273.34480382] [-7273.34504827]
[-2010.61773247] [-2010.61839251] [-2010.61773225]
[-7917.81185096] [-7917.81223606] [-7917.81185084]
[ 8247.92773739] [ 8247.92774315] [ 8247.92773722]
[ 1267.25067823] [ 1267.24677734] [ 1267.25067832]
[ 2557.6208133] [ 2557.62126916] [ 2557.62081337]
[-5678.53744654] [-5678.53820798] [-5678.53744647]
[ 3406.41697822] [ 3406.42040997] [ 3406.41697836]
[-8371.23657044] [-8371.2361594] [-8371.23657035]
[ 15010.61728285] [ 15010.61598236] [ 15010.61728304]
[ 11006.21920273] [ 11006.21711213] [ 11006.21920284]
[-5930.93274062] [-5930.93237071] [-5930.93274048]
[-5232.84459862] [-5232.84557665] [-5232.84459848]
[ 3196.89304277] [ 3196.89414431] [ 3196.8930428]
[ 15298.53309912] [ 15298.53496877] [ 15298.53309919]
[ 4742.68631183] [ 4742.6862601] [ 4742.68631172]
[ 4423.14798495] [ 4423.14765013] [ 4423.14798546]
[-16153.50854089] [-16153.51038489] [-16153.50854123]
[-22071.50792741] [-22071.49808389] [-22071.50792408]
[-688.22903323] [-688.2310229] [-688.22904006]
[-1060.88119863] [-1060.8829114] [-1060.88120546]
[-101.75750066] [-101.75776411] [-101.75750831]
[ 4106.77311898] [ 4106.77128502] [ 4106.77311218]
[ 3482.99764601] [ 3482.99518758] [ 3482.99763924]
[-1100.42290509] [-1100.42166312] [-1100.4229119]
[ 20892.42685103] [ 20892.42487476] [ 20892.42684422]
[-5007.54075789] [-5007.54265501] [-5007.54076473]
[ 11111.83929421] [ 11111.83734144] [ 11111.83928704]
[ 9488.57342568] [ 9488.57158677] [ 9488.57341883]
[-2992.3070786] [-2992.29295891] [-2992.30708529]
[ 17810.57005982] [ 17810.56651223] [ 17810.57005457]
[-2154.47389712] [-2154.47504319] [-2154.47390285]
[-5324.34206726] [-5324.33913623] [-5324.34207293]
[-14981.89224345] [-14981.8965674] [-14981.89224973]
[-29440.90545197] [-29440.90465897] [-29440.90545704]
[-6925.31991443] [-6925.32123144] [-6925.31992383]
[ 104.98071593] [ 104.97886085] [ 104.98071152]
[-5184.94477582] [-5184.9447972] [-5184.94477792]
[ 1555.54536625] [ 1555.54254362] [ 1555.5453638]
[-402.62443474] [-402.62539068] [-402.62443718]
[ 17746.15769322] [ 17746.15458093] [ 17746.15769074]
[-5512.94925026] [-5512.94980649] [-5512.94925267]
[-2202.8589276] [-2202.86226244] [-2202.85893056]
[-5549.05250407] [-5549.05416936] [-5549.05250669]
[-1675.87329493] [-1675.87995809] [-1675.87329255]
[-5274.27756529] [-5274.28093377] [-5274.2775701]
[-5424.10246845] [-5424.10658526] [-5424.10247326]
[-1014.70864363] [-1014.71145066] [-1014.70864845]
[ 12936.59360437] [ 12936.59168749] [ 12936.59359954]
[ 2912.71566077] [ 2912.71282628] [ 2912.71565599]
[ 6489.36648506] [ 6489.36538259] [ 6489.36648021]
[ 12025.06991281] [ 12025.07040848] [ 12025.06990358]
[ 17026.57841531] [ 17026.56827742] [ 17026.57841044]
[ 2220.1852193] [ 2220.18531961] [ 2220.18521579]
[-2886.39219026] [-2886.39015388] [-2886.39219394]
[-18393.24573629] [-18393.25888463] [-18393.24573872]
[-17591.33051471] [-17591.32838012] [-17591.33051834]
[-3947.18545848] [-3947.17487999] [-3947.18546459]
[ 7707.05472816] [ 7707.05577227] [ 7707.0547217]
[ 4280.72039079] [ 4280.72338194] [ 4280.72038435]
[-3137.48835901] [-3137.48480197] [-3137.48836531]
[ 6693.47303443] [ 6693.46528167] [ 6693.47302811]
[-13936.14265517] [-13936.14329336] [-13936.14267094]
[ 2684.29594641] [ 2684.29859601] [ 2684.29594183]
[-2193.61036078] [-2193.63086307] [-2193.610366]
[-10139.10424848] [-10139.11905454] [-10139.10426049]
[ 4475.11569903] [ 4475.12288711] [ 4475.11569421]
[-3037.71857269] [-3037.72118246] [-3037.71857265]
[-5538.71349798] [-5538.71654224] [-5538.71349794]
[ 8008.38521357] [ 8008.39092739] [ 8008.38521361]
[-1433.43859633] [-1433.44181824] [-1433.43859629]
[ 4212.47144667] [ 4212.47368097] [ 4212.47144686]
[ 19688.24263706] [ 19688.2451694] [ 19688.2426368]
[ 104.13434091] [ 104.13434349] [ 104.13434091]
[-654.02451175] [-654.02493111] [-654.02451174]
[-2522.8642551] [-2522.88694451] [-2522.86424254]
[-5011.20385919] [-5011.22742915] [-5011.20384655]
[-13285.64644021] [-13285.66951459] [-13285.64642763]
[-4254.86406891] [-4254.88695873] [-4254.86405637]
[-2477.42063206] [-2477.43501057] [-2477.42061727]
[ 0.] [ 1.23691279e-10] [ 0.]
[-92.79470071] [-92.79467095] [-92.79470071]
[ 2383.66211583] [ 2383.66209637] [ 2383.66211583]
[-10725.22892185] [-10725.22889937] [-10725.22892185]
[ 234.77560283] [ 234.77560254] [ 234.77560283]
[ 4739.22119578] [ 4739.22121432] [ 4739.22119578]
[ 43640.05854156] [ 43640.05848841] [ 43640.05854157]
[ 2592.3866707] [ 2592.38671547] [ 2592.3866707]
[-25130.02819215] [-25130.05501178] [-25130.02819515]
[ 4966.82173096] [ 4966.7946407] [ 4966.82172795]
[ 14232.97930665] [ 14232.9529959] [ 14232.97930363]
[-21621.77202422] [-21621.79840459] [-21621.7720272]
[ 9917.80960029] [ 9917.80960571] [ 9917.80960029]
[ 1355.79191536] [ 1355.79198092] [ 1355.79191536]
[-27218.44185748] [-27218.46880642] [-27218.44185719]
[-27218.04184348] [-27218.06875423] [-27218.04184318]
[ 23482.80743869] [ 23482.78043029] [ 23482.80743898]
[ 3401.67707434] [ 3401.65134677] [ 3401.67707463]
[ 3030.36383274] [ 3030.36384909] [ 3030.36383274]
[-30590.61847724] [-30590.63933424] [-30590.61847706]
[-28818.3942685] [-28818.41520495] [-28818.39426833]
[-25115.73726772] [-25115.7580278] [-25115.73726753]
[ 77174.61695995] [ 77174.59548773] [ 77174.61696016]
[-20201.86613672] [-20201.88871113] [-20201.86613657]
[ 51908.53292209] [ 51908.53446495] [ 51908.53292207]
[ 7710.71327865] [ 7710.71324194] [ 7710.71327865]
[-16206.9785119] [-16206.97851993] [-16206.9785119]
正如您可以看到正规方程,最小二乘法和np.linalg.solve工具方法在某种程度上给出了不同的结果。问题是为什么这三种方法会产生明显不同的结果,哪种方法可以提供更高效和更准确的结果?
假设: 正规方程方法的结果和 np.linalg.solve 的结果非常接近。 np.linalg.lstsq 的结果与两者不同。由于正规方程使用逆,我们不期望它的结果非常准确,因此np.linalg.solve工具的结果也是如此。似乎 np.linalg.lstsq 给出了更好的结果。
UPD:
Dave Hensley 提到:
行后 np.fill_diagonal(IdentityMatrix, 1)
应添加此代码 IdentityMatrix[0,0] = 0
。
DB2.csv :DB2.csv
DropBox上提供了完整的Python代码:Full code
答案 0 :(得分:13)
专业算法无法求解矩阵逆。它很慢并且引入了不必要的错误。对于小型系统来说,这不是一场灾难,但为什么做一些不理想的事情呢?
基本上,只要你看到数学写成:
x = A^-1 * b
你想要:
x = np.linalg.solve(A, b)
在你的情况下,你需要类似的东西:
XtX_lamb = X.T.dot(X) + lamb * IdentityMatrix
XtY = X.T.dot(Y)
x = np.linalg.solve(XtX_lamb, XtY);
答案 1 :(得分:10)
正如 @Matthew Gunn 所提到的,计算系数矩阵的显式逆作为求解线性方程组的一种方法是不好的做法。直接获得解决方案更快,更准确(see here)。
您看到np.linalg.solve
和np.linalg.lstsq
之间存在差异的原因是因为这些函数对您尝试解决的系统做出了不同的假设,并使用了不同的数值方法。
在幕后,solve
调用DGESV LAPACK routine,后者使用LU分解,然后进行前向和后向替换,以找到{{1>精确解决方案{{1} }}。它要求系统准确确定,即Ax = b
是正方形且满级。
A
会调用DGELSD,它使用lstsq
的奇异值分解来查找最小二乘解决方案。这也适用于过度确定和未确定的情况。
如果您的系统已完全确定,那么您应该使用A
,因为它需要更少的浮点运算,因此更快,更精确。在您的情况下,由于正则化步骤,solve
保证满级。
答案 2 :(得分:3)
其他答案定义了为什么理论上一种计算方法优于其他计算方法。然而,他们没有办法测试哪种解决方案实际上显示出更好的结果。这是:
def test(a, x, b):
res = a.dot(x).as_matrix() - b.as_matrix()
print(np.linalg.norm(res))
test(XtX_lamb, x, XtY)
test(XtX_lamb, th, XtY)
test(XtX_lamb, theta, XtY)
这会影响线性系统误差向量的范数2。 结果是:
np.linalg.solve - 0.000488340357871
np.linalg.lstsq - 1.75520748498
normal equation - 16.1628614202
因此linalg.solve确实显示出最准确的结果。
答案 3 :(得分:1)
我认为您的实施中存在一个影响所有3个计算的错误。您使用以下代码生成IdentityMatrix:
IdentityMatrix= np.zeros((IdentitySize, IdentitySize))
np.fill_diagonal(IdentityMatrix, 1)
(您实际上可以将其简化为IdentityMatrix=np.eye(IdentitySize)
)
单位矩阵是这个(当IdentitySize == 3时):
1 0 0
0 1 0
0 0 1
但你应该使用的是这个(同样的东西,但左上角是0):
0 0 0
0 1 0
0 0 1