PyTorch: CrossEntropyLoss, changing the class weights does not change the computed loss

Date: 2019-01-19 09:22:21

Tags: machine-learning pytorch

According to the docs for cross-entropy loss, the weighted loss is computed by multiplying the weight of each class with the original per-sample loss.
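To make my reading of the docs concrete, here is a sketch of the computation I expected (this is my own interpretation, not PyTorch code):

import torch
import torch.nn.functional as F

# My reading of the docs: loss_i = -weight[y_i] * log_softmax(logits_i)[y_i]
logits = torch.FloatTensor([[0.1, 0.9]])
label = torch.LongTensor([0])
weight = torch.FloatTensor([0.1, 1])

log_probs = F.log_softmax(logits, dim=1)          # log-probabilities per class
picked = log_probs.gather(1, label.unsqueeze(1))  # log-prob of the true class
per_sample = -weight[label] * picked.squeeze(1)   # scaled by the class weight
print(per_sample)  # tensor([0.1171]) under this reading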

However, in the PyTorch implementation, the class weights seem to have no effect on the final loss value unless a weight is set to zero. Here is the code:

from torch import nn
import torch

logits = torch.FloatTensor([
    [0.1, 0.9],
])
label = torch.LongTensor([0])

criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor([1, 1]))
loss = criterion(logits, label)
print(loss.item()) # result: 1.1711

# Change the class weight for the first class to 0.1
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor([0.1, 1]))
loss = criterion(logits, label)
print(loss.item()) # result: 1.1711, but I expected 0.11711

# Change the class weight for the first class to 0
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor([0, 1]))
loss = criterion(logits, label)
print(loss.item()) # result: 0

As the code shows, a class weight seems to have no effect unless it is set to 0, and this behavior contradicts the documentation.
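To check whether the weights are applied at all before reduction, the unreduced per-sample losses can be inspected (assuming your PyTorch version supports reduction='none'):

# The unreduced loss does reflect the weight; only the reduced value is unchanged.
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor([0.1, 1]),
                                reduction='none')
print(criterion(logits, label))  # tensor([0.1171]): the weight is applied here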

Update
I implemented my own version of weighted cross entropy, which in my view is the "correct" way to do it.

import torch
from torch import nn

def weighted_cross_entropy(logits, label, weight=None):
    assert len(logits.size()) == 2
    batch_size, label_num = logits.size()
    assert batch_size == label.size(0)

    if weight is None:
        weight = torch.ones(label_num).float()

    assert label_num == weight.size(0)

    # Negative logit of the true class for each sample
    x_terms = -torch.gather(logits, 1, label.unsqueeze(1)).squeeze()
    # log-sum-exp over the classes; x_terms + log_terms = -log softmax(true class)
    log_terms = torch.log(torch.sum(torch.exp(logits), dim=1))

    # Weight of each sample's target class
    weights = torch.gather(weight, 0, label).float()

    # Scale each per-sample loss by its class weight, then average over the batch
    return torch.mean((x_terms + log_terms) * weights)
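As an aside, computing torch.log(torch.sum(torch.exp(...))) explicitly can overflow for large logits; torch.logsumexp (available in recent PyTorch versions) is a numerically stable drop-in for that line:

# Numerically stable alternative for the log_terms line above:
log_terms = torch.logsumexp(logits, dim=1)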

logits = torch.FloatTensor([
    [0.1, 0.9],
    [0.0, 0.1],
])

label = torch.LongTensor([0, 1])

neg_weight = 0.1

weight = torch.FloatTensor([neg_weight, 1])

criterion = nn.CrossEntropyLoss(weight=weight)
loss = criterion(logits, label)

print(loss.item()) # result: 0.69227
print(weighted_cross_entropy(logits, label, weight).item()) # result: 0.38075

What I do is multiply each instance in the batch by the weight of its target class and then average over the batch. The result still differs from the built-in PyTorch implementation, which makes me wonder how PyTorch actually computes it.
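For what it is worth, if the weighted sum is normalized by the sum of the per-sample weights rather than by the batch size, the numbers match PyTorch's output. The snippet below (reusing logits, label, and weight from the script above) is only my guess at what the built-in reduction might be doing:

# Guess: normalize by the sum of per-sample weights instead of the batch size.
per_sample_w = torch.gather(weight, 0, label)  # weight of each sample's class
x_terms = -torch.gather(logits, 1, label.unsqueeze(1)).squeeze()
log_terms = torch.log(torch.sum(torch.exp(logits), dim=1))
weighted = (x_terms + log_terms) * per_sample_w
print((weighted.sum() / per_sample_w.sum()).item())  # 0.69227, matches PyTorch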

0 Answers

No answers yet.