PyTorch: how can I make the tensors model(x) and answer(x) the same size?

Time: 2018-12-01 08:22:20

Tags: python pytorch

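I am trying to build a simple linear model that predicts the parameters of a formula.

y = 3 * x1 + x2 - 2 * x3

Unfortunately, something goes wrong when I try to compute the loss.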

def answer(x):
    return 3 * x[:,0] + x[:,1] - 2 * x[:,2]


def loss_f(x):

    y = answer(x)
    y_hat = model(x)
    loss = ((y - y_hat).pow(2)).sum() / x.size(0)

    return loss

When I set batch_size to 3, the two results have different sizes:

x = torch.randn(3,3)
answer(x)
tensor([ 2.0201, -3.8354,  2.0059])

model(x)
tensor([[ 0.2085],
    [-0.0670],
    [-1.3635]], grad_fn=<ThAddmmBackward>)

answer(x.data).size()
torch.Size([3])
model(x.data).size()
torch.Size([3, 1])

I thought broadcasting would be applied automatically.

loss = ((y - y_hat).pow(2)).sum() / x.size(0)
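
For reference, here is a minimal sketch of what broadcasting does with these two shapes. The tensors y and y_hat below are hand-written stand-ins (an assumption for illustration), not the output of the model in this question. Broadcasting is applied, but it expands (3,) and (3, 1) into a (3, 3) matrix of all pairwise differences rather than matching the elements one to one.

```python
import torch

# Stand-in tensors with the same shapes as answer(x) and model(x).
y = torch.tensor([1.0, 2.0, 3.0])            # shape (3,)
y_hat = torch.tensor([[1.0], [2.0], [3.0]])  # shape (3, 1)

# (3,) - (3, 1) broadcasts to (3, 3): every y[j] minus every y_hat[i].
diff = y - y_hat
print(diff.shape)  # torch.Size([3, 3])

# The values agree element for element, yet the "loss" is not zero,
# because the off-diagonal pairs contribute to the sum as well.
loss = diff.pow(2).sum() / y.size(0)
print(loss)  # tensor(4.)
```

So the expression runs without an error, but the value it produces is not the mean squared error over the batch.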

How can I make the two tensors the same size? Thanks.
Here is my code:

import torch
import torch.nn as nn
import torch.optim as optim

class model(nn.Module):

    def __init__(self, input_size, output_size):
        super(model, self).__init__()

        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):

        y = self.linear(x)

        return y

model = model(3,1)
optimizer = optim.SGD(model.parameters(), lr = 0.001, momentum=0.1)

print('Parameters : ')
for p in model.parameters():
    print(p)

print('')
print('Optimizer : ')
print(optimizer)

def generate_data(batch_size):
    x = torch.randn(batch_size, 3)

    return x

def answer(x):

    return 3 * x[:,0] + x[:,1] - 2 * x[:,2]

def loss_f(x):

    y = answer(x)
    y_hat = model(x)

    loss = ((y - y_hat).pow(2)).sum() / x.size(0)

    return loss

x = torch.randn(3,3)
print(x)
x = torch.FloatTensor(x)

batch_size = 3
epoch_n = 1000
iter_n = 100

for epoch in range(epoch_n):
    avg_loss = 0

    for i in range(iter_n):
        x = torch.randn(batch_size, 3)

        optimizer.zero_grad()
        loss = loss_f(x.data)
        loss.backward()
        optimizer.step()

        # accumulate a Python float so the autograd graph is not kept alive
        avg_loss += loss.item()

    avg_loss = avg_loss / iter_n

    x_valid = torch.FloatTensor([[1,2,3]])
    y_valid = answer(x_valid)

    model.eval()
    y_hat = model(x_valid)
    model.train()

    print(avg_loss, y_valid.data[0], y_hat.data[0])

    if avg_loss < 0.001:
        break

1 Answer:

Answer 0 (score: 1)

You can use Tensor.view:

https://pytorch.org/docs/stable/tensors.html#torch.Tensor.view

Something like

answer(x.data).view(-1, 1)

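should do the trick. It reshapes the target from shape (batch_size,) to (batch_size, 1), which matches the model output, so the subtraction is element-wise instead of broadcasting to a (batch_size, batch_size) matrix.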
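
As a rough sketch of how this could fit into the loss function from the question: the name net and the fixed seed are assumptions made for this example (net is a plain nn.Linear(3, 1) standing in for the question's model), and the rest follows the code as posted.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)   # only to make the sketch repeatable
net = nn.Linear(3, 1)  # hypothetical stand-in for the question's model


def answer(x):
    return 3 * x[:, 0] + x[:, 1] - 2 * x[:, 2]


def loss_f(x):
    y = answer(x).view(-1, 1)  # (batch,) -> (batch, 1)
    y_hat = net(x)             # (batch, 1)
    # Shapes now match, so this is the mean squared error over the batch.
    return (y - y_hat).pow(2).sum() / x.size(0)


x = torch.randn(3, 3)
print(answer(x).view(-1, 1).shape)  # torch.Size([3, 1])
print(loss_f(x).shape)              # torch.Size([])
```

Going the other way should work just as well: net(x).squeeze(1) drops the trailing dimension of the prediction so that both tensors have shape (batch,).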