根据Doc的交叉熵损失,通过将每个类别的权重与原始损失相乘来计算加权损失。
但是,在pytorch实现中,除非将权重设置为零,否则类权重似乎对最终损失值没有影响。以下是代码:
from torch import nn
import torch
logits = torch.FloatTensor([
[0.1, 0.9],
])
label = torch.LongTensor([0])
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor([1, 1]))
loss = criterion(logits, label)
print(loss.item()) # result: 1.1711
# Change class weight for the first class to 0.1
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor([0.1, 1]))
loss = criterion(logits, label)
print(loss.item()) # result: 1.1711, should be 0.11711
# Change weight for first class to 0
criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor([0, 1]))
loss = criterion(logits, label)
print(loss.item()) # result: 0
如代码中所示,除非将类权重设置为0,否则它似乎没有任何作用,这种行为与文档相矛盾。
更新
我实现了加权交叉熵的一种版本,在我看来这是“正确”的方法。
import torch
from torch import nn
def weighted_cross_entropy(logits, label, weight=None):
assert len(logits.size()) == 2
batch_size, label_num = logits.size()
assert (batch_size == label.size(0))
if weight is None:
weight = torch.ones(label_num).float()
assert (label_num == weight.size(0))
x_terms = -torch.gather(logits, 1, label.unsqueeze(1)).squeeze()
log_terms = torch.log(torch.sum(torch.exp(logits), dim=1))
weights = torch.gather(weight, 0, label).float()
return torch.mean((x_terms+log_terms)*weights)
logits = torch.FloatTensor([
[0.1, 0.9],
[0.0, 0.1],
])
label = torch.LongTensor([0, 1])
neg_weight = 0.1
weight = torch.FloatTensor([neg_weight, 1])
criterion = nn.CrossEntropyLoss(weight=weight)
loss = criterion(logits, label)
print(loss.item()) # results: 0.69227
print(weighted_cross_entropy(logits, label, weight).item()) # results: 0.38075
我所做的是将批处理中的每个实例与其关联的类权重相乘。结果仍然与原始pytorch实现不同,这使我想知道pytorch如何实际实现这一点。