为什么损失函数可以应用于不同大小的张量

时间:2018-08-13 12:58:16

标签: pytorch

例如,我有一个网络,以张量[N,7](N是样本数)作为输入,张量[N,4]作为输出,“ 4”代表不同类的概率。

训练数据的标签为张量[N]的形式,范围为0到3(表示真实级别)。

这是我的问题,我看过一些演示,他们将损失函数直接应用于输出张量和标签张量。我想知道为什么这种方法行得通,因为它们的大小不同,而且大小似乎不符合“广播语义”。

这是最小的演示。

import torch
import torch.nn as nn
import torch.optim as optim

if __name__ == '__main__':
    features = torch.randn(2, 7)
    gt = torch.tensor([1, 1])
    model = nn.Sequential(
        nn.Linear(7, 4),
        nn.ReLU(),
        nn.Linear(4, 4)
    )
    optimizer = optim.SGD(model.parameters(), lr=0.005)
    f = nn.CrossEntropyLoss()

    for epoch in range(1000):
        optimizer.zero_grad()
        output = model(features)
        loss = f(output, gt)
        loss.backward()
        optimizer.step()

1 个答案:

答案 0 :(得分:1)

在PyTorch中,实现为:

pytorch cross entropy

文档链接:https://pytorch.org/docs/stable/nn.html#torch.nn.CrossEntropyLoss


因此在pytorch中实现此公式,您将得到:

import torch
import torch.nn.functional as F

output = torch.tensor([ 0.1998, -0.2261, -0.0388,  0.1457])
target = torch.LongTensor([1])

# implementing the formula above
print('manual  cross-entropy:', (-output[target] + torch.log(torch.sum(torch.exp(output))))[0])

# calling build in cross entropy function to check the result
print('pytorch cross-entropy:', F.cross_entropy(output.unsqueeze(0), target))

输出:

manual  cross-entropy: tensor(1.6462)
pytorch cross-entropy: tensor(1.6462)

我希望这对您有所帮助,对不起您造成的困惑。