使用缩减功能,使用pytorch计算损耗很容易:
criterion = nn.CrossEntropyLoss(weight=cw, reduction="mean")
其中cw
是一类权重的一维张量。
masks_pred = net(imgs)
loss = criterion(masks_pred, masks)
在这种情况下,loss
是表示加权损失的标量(在loss.item()
之后),可以与使用以下内容相同:
criterion = nn.CrossEntropyLoss(weight=cw, reduction="none")
masks_pred = net(imgs)
loss = criterion(masks_pred, masks)
loss = loss.sum() / cw[masks].sum()
我的问题是如何获得每个班级的加权损失?从理论上讲,我知道我需要计算以下内容:
x类损失总和/ x类重量之和
我尝试过的事情:
我尝试做的第一件事是获得特定类权重的总和超过masks
(例如,类0):
cw[masks==0]
但我收到此错误:
IndexError:维度1的张量的索引过多
我也不知道每班损失的总和。