我需要在Cifar10上训练一个分类器,以便它可以对标签为0和1的图像进行分类,而对于这两个标签以外的其他标签,分类器可提供无偏(中性)结果。我尝试使用以下代码进行BCE损失训练:
input = model(x)
not_ind_tgt = ~((target==0) | (target==1))
tgt_hot = F.one_hot(target, 10).float()
multi_label = 1./8. * torch.ones(1,10).float()
multi_label[0, 0: 2] = 0.0
tgt_hot[not_ind_tgt] = multi_label
loss = nn.BCEWithLogitsLoss()(input, tgt_hot).mean()
但是,经过训练的分类器在0和1类上的准确性为零。