如何重现通过pytorch中的伪逆完成的线性回归

时间:2019-03-30 13:50:19

标签: linear-regression pytorch

我尝试使用pytorch重现简单的线性回归x = A†b。但是我得到的数字完全不同。

所以首先我使用普通的numpy来做

A_pinv = np.linalg.pinv(A)
betas = A_pinv.dot(b)

print(((b - A.dot(betas))**2).mean())
print(betas)

结果为:

364.12875
[0.43196774 0.14436531 0.42414093]

现在,我尝试使用pytorch获得足够多的数字:

# re-implement via pytoch model using built-ins
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

# We'll create a TensorDataset, which allows access to rows from inputs and targets as tuples.
# We'll also create a DataLoader, to split the data into batches while training. 
# It also provides other utilities like shuffling and sampling.

inputs = to.from_numpy(A)
targets = to.from_numpy(b)
train_ds = TensorDataset(inputs, targets)

batch_size = 5
train_dl = DataLoader(train_ds, batch_size, shuffle=True)

# define model, loss and optimizer
new_model = nn.Linear(source_variables, predict_variables, bias=False)
loss_fn = F.mse_loss
opt = to.optim.SGD(new_model.parameters(), lr=1e-10)

def fit(num_epochs, new_model, loss_fn, opt):
    for epoch in tnrange(num_epochs, desc="epoch"):
        for xb,yb in train_dl:
            # Generate predictions
            pred = new_model(xb)
            loss = loss_fn(pred, yb)
            # Perform gradient descent
            loss.backward()
            opt.step()
            opt.zero_grad()

        if epoch % 1000 == 0:
            print((new_model.weight, loss))
    print('Training loss: ', loss_fn(model(inputs), targets))

# fit the model
fit(10000, new_model, loss_fn, opt)

它作为最后的结果打印:

tensor([[0.0231, 0.5185, 0.4589]], requires_grad=True), tensor(271.8525, grad_fn=<MseLossBackward>))
Training loss:  tensor(378.2871, grad_fn=<MseLossBackward>)

您可以看到这些数字完全不同,所以我一定在某个地方犯了一个错误...

以下是Ab的数字,用于再现结果:

A = np.array([[2822.48, 2808.48, 2810.92],
             [2832.94, 2822.48, 2808.48],
             [2832.57, 2832.94, 2822.48],
             [2824.23, 2832.57, 2832.94],
             [2854.88, 2824.23, 2832.57],
             [2800.71, 2854.88, 2824.23],
             [2798.36, 2800.71, 2854.88],
             [2818.46, 2798.36, 2800.71],
             [2805.37, 2818.46, 2798.36],
             [2815.44, 2805.37, 2818.46]], dtype=float32)

b = np.array([2832.94, 2832.57, 2824.23, 2854.88, 2800.71, 2798.36, 2818.46, 2805.37, 2815.44, 2834.4 ], dtype=float32)

0 个答案:

没有答案