I tried to reproduce a simple linear regression x = A†b with PyTorch, but I get completely different numbers.
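For reference, here is what x = A†b computes. This is a short derivation that assumes A has full column rank, which is stated as an assumption rather than checked here. The pseudoinverse solution is the least-squares minimizer, and the quantity printed as "mean" in the NumPy code below is the mean squared residual (1/n)‖b − Ax‖²:

```latex
\hat{x} = \arg\min_{x} \lVert A x - b \rVert_2^2
\;\Longrightarrow\;
A^\top A \,\hat{x} = A^\top b
\;\Longrightarrow\;
\hat{x} = (A^\top A)^{-1} A^\top b = A^\dagger b
```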
So first I did it with plain NumPy:
import numpy as np

A_pinv = np.linalg.pinv(A)
betas = A_pinv.dot(b)
print(((b - A.dot(betas))**2).mean())
print(betas)
The result is:
364.12875
[0.43196774 0.14436531 0.42414093]
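As a sanity check of the NumPy route, here is a minimal sketch showing that `np.linalg.pinv`, `np.linalg.lstsq` and the normal equations give the same answer when A has full column rank. It runs on random stand-in data of the same 10 × 3 shape. The seed and the data are arbitrary assumptions and are not the A and b from this question:

```python
import numpy as np

# Random stand-in data with the same shape as the question's A (10 x 3) and b (10,).
rng = np.random.default_rng(0)
A = rng.normal(size=(10, 3))
b = rng.normal(size=10)

# 1) Pseudoinverse, as in the question.
x_pinv = np.linalg.pinv(A) @ b
# 2) Dedicated least-squares solver; rcond=None uses a machine-precision cutoff.
x_lstsq, residuals, rank, sing_vals = np.linalg.lstsq(A, b, rcond=None)
# 3) Normal equations A^T A x = A^T b (fine here, but squares the condition number).
x_normal = np.linalg.solve(A.T @ A, A.T @ b)

# A least-squares residual is orthogonal to every column of A.
r = b - A @ x_pinv
print(x_pinv, x_lstsq, x_normal)
print("A^T r =", A.T @ r)
```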
Now I tried to get close enough numbers with PyTorch:
# re-implement via PyTorch model using built-ins
import torch as to
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from tqdm.notebook import tnrange  # Jupyter progress bar; use tqdm.auto.trange outside a notebook

# We'll create a TensorDataset, which allows access to rows from inputs and targets as tuples.
# We'll also create a DataLoader, to split the data into batches while training.
# It also provides other utilities like shuffling and sampling.
inputs = to.from_numpy(A)
targets = to.from_numpy(b)
train_ds = TensorDataset(inputs, targets)
batch_size = 5
train_dl = DataLoader(train_ds, batch_size, shuffle=True)

# define model, loss and optimizer
source_variables = A.shape[1]  # 3 input columns
predict_variables = 1          # one prediction per row
new_model = nn.Linear(source_variables, predict_variables, bias=False)
loss_fn = F.mse_loss
opt = to.optim.SGD(new_model.parameters(), lr=1e-10)

def fit(num_epochs, new_model, loss_fn, opt):
    for epoch in tnrange(num_epochs, desc="epoch"):
        for xb, yb in train_dl:
            # Generate predictions
            pred = new_model(xb)
            loss = loss_fn(pred, yb)
            # Perform gradient descent
            loss.backward()
            opt.step()
            opt.zero_grad()
        if epoch % 1000 == 0:
            print((new_model.weight, loss))
    print('Training loss: ', loss_fn(new_model(inputs), targets))

# fit the model
fit(10000, new_model, loss_fn, opt)
As the final result it prints:
tensor([[0.0231, 0.5185, 0.4589]], requires_grad=True), tensor(271.8525, grad_fn=<MseLossBackward>))
Training loss: tensor(378.2871, grad_fn=<MseLossBackward>)
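One thing worth checking in the training loop: `nn.Linear(3, 1)` returns predictions of shape `(batch, 1)`, while each `yb` sliced from the 1-D `targets` has shape `(batch,)`. The following is a minimal, self-contained sketch on toy tensors (not the question's data) of what `F.mse_loss` does with such a pair. It emits a UserWarning and broadcasts both tensors to `(batch, batch)`, so the result is no longer the per-row squared error. In the sketch the predictions are exactly right, yet the broadcast loss is nonzero:

```python
import warnings

import torch
import torch.nn.functional as F

# Toy batch of 5 rows: predictions equal the targets, only the shapes differ.
target = torch.arange(5.0)   # shape (5,), like yb taken from a 1-D targets tensor
pred = target.unsqueeze(1)   # shape (5, 1), like the output of nn.Linear(3, 1)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # (5, 1) vs (5,) broadcasts to (5, 5): every prediction is compared with every target.
    broadcast_loss = F.mse_loss(pred, target)

# Matching shapes compare row i with row i only.
aligned_loss = F.mse_loss(pred.squeeze(1), target)

print(broadcast_loss.item(), aligned_loss.item())  # 4.0 0.0
print([str(w.message) for w in caught])
```

In the question's code, one change consistent with this sketch would be `targets = to.from_numpy(b).unsqueeze(1)`, or squeezing `pred` inside the loop. This is a hedged suggestion. The very small learning rate is a separate matter that this shape fix does not address.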
As you can see, the numbers are completely different, so I must have made a mistake somewhere...
Here are the values of A and b to reproduce the results:
import numpy as np

A = np.array([[2822.48, 2808.48, 2810.92],
              [2832.94, 2822.48, 2808.48],
              [2832.57, 2832.94, 2822.48],
              [2824.23, 2832.57, 2832.94],
              [2854.88, 2824.23, 2832.57],
              [2800.71, 2854.88, 2824.23],
              [2798.36, 2800.71, 2854.88],
              [2818.46, 2798.36, 2800.71],
              [2805.37, 2818.46, 2798.36],
              [2815.44, 2805.37, 2818.46]], dtype=np.float32)
b = np.array([2832.94, 2832.57, 2824.23, 2854.88, 2800.71, 2798.36, 2818.46, 2805.37, 2815.44, 2834.4], dtype=np.float32)
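With the A and b above, the following is a minimal sketch that reproduces the NumPy pseudoinverse directly in PyTorch and measures how ill-conditioned this A is. It assumes a PyTorch version that provides `torch.linalg.pinv`, `torch.linalg.lstsq` and `torch.linalg.svdvals` (1.9 or later). Three choices here are additions for illustration and are not part of the question: the float64 cast, treating (2/n)·AᵀA as the Hessian of the MSE loss, and the full-batch gradient-descent step-size bound 2/λ_max. Mini-batch SGD, as used in the question, only roughly follows that bound.

```python
import numpy as np
import torch

# Same values as the question's A and b, kept in float64 (NumPy's default) to limit rounding error.
A = np.array([[2822.48, 2808.48, 2810.92],
              [2832.94, 2822.48, 2808.48],
              [2832.57, 2832.94, 2822.48],
              [2824.23, 2832.57, 2832.94],
              [2854.88, 2824.23, 2832.57],
              [2800.71, 2854.88, 2824.23],
              [2798.36, 2800.71, 2854.88],
              [2818.46, 2798.36, 2800.71],
              [2805.37, 2818.46, 2798.36],
              [2815.44, 2805.37, 2818.46]])
b = np.array([2832.94, 2832.57, 2824.23, 2854.88, 2800.71,
              2798.36, 2818.46, 2805.37, 2815.44, 2834.4])

A_t = torch.from_numpy(A)
b_t = torch.from_numpy(b)

# Closed-form least squares in PyTorch: the same quantity as np.linalg.pinv(A).dot(b).
x_pinv_torch = torch.linalg.pinv(A_t) @ b_t
x_lstsq_torch = torch.linalg.lstsq(A_t, b_t.unsqueeze(1)).solution.squeeze(1)
x_numpy = np.linalg.pinv(A) @ b
mse = torch.mean((b_t - A_t @ x_pinv_torch) ** 2)

# Conditioning: singular values of A (descending) and eigenvalues of H = (2/n) A^T A.
sing_vals = torch.linalg.svdvals(A_t)
cond_A = (sing_vals[0] / sing_vals[-1]).item()
n = A_t.shape[0]
hess_eigs = 2.0 / n * sing_vals ** 2
# Full-batch gradient descent on the MSE diverges for learning rates above 2 / lambda_max.
max_stable_lr = (2.0 / hess_eigs[0]).item()
# With lr near 1 / lambda_max, shrinking the error along the slowest direction by a factor e
# takes on the order of lambda_max / lambda_min steps.
eig_ratio = (hess_eigs[0] / hess_eigs[-1]).item()

print(x_pinv_torch, x_lstsq_torch, x_numpy, mse.item())
print(f"cond(A) = {cond_A:.1f}, max stable lr = {max_stable_lr:.2e}, "
      f"lambda_max / lambda_min = {eig_ratio:.1e}")
```

On this A the ratio λ_max/λ_min exceeds 10⁴, because the three columns are nearly identical. Plain (S)GD therefore needs a very large number of steps to pin down the coefficients along the flattest direction, even with a well-chosen learning rate, and lr=1e-10 lies well below the stable bound. That would be consistent with SGD reaching a loss in the same range while landing on different weights. A closed-form solve avoids the issue.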