使用权重计算多类细分中每个类的损失

时间:2020-10-15 11:15:23

标签: python machine-learning pytorch image-segmentation

使用缩减功能,使用pytorch计算损耗很容易:

criterion = nn.CrossEntropyLoss(weight=cw, reduction="mean")

其中cw是一类权重的一维张量。

masks_pred = net(imgs)
loss = criterion(masks_pred, masks)

在这种情况下,loss是表示加权损失的标量(在loss.item()之后),可以与使用以下内容相同:

criterion = nn.CrossEntropyLoss(weight=cw, reduction="none")
masks_pred = net(imgs)
loss = criterion(masks_pred, masks)
loss = loss.sum() / cw[masks].sum()

我的问题是如何获得每个班级的加权损失?从理论上讲,我知道我需要计算以下内容:

x类损失总和/ x类重量之和

我尝试过的事情:

我尝试做的第一件事是获得特定类权重的总和超过masks(例如,类0):

cw[masks==0]

但我收到此错误:

IndexError:维度1的张量的索引过多

我也不知道每班损失的总和。

0 个答案:

没有答案