我正在尝试在PyTorch中编写自定义损失函数(降噪损失)。它与交叉熵损失非常相似,不同之处在于,它假定预测的答案中的某些标签不正确,从而使它对所预测的答案具有一定的置信度(预测矩阵中的最高概率)。这里pred表示预测的[m * L]矩阵,其中m是示例数,L是标签数,y_true是实际标签的[m * 1]矩阵,“ ro”是决定每个标签的影响的超参数所使用的两个标准中的一个。
def lossNR(pred, y_true, ro):
outputs = torch.log(pred) # compute the log of softmax values
out1 = outputs.gather(1, y_true.view([-1,1])) # pick the values corresponding to the labels
l1 = -((ro)* torch.mean(out1))
l2 = -(1-ro) * torch.mean((torch.max(outputs,1)[0]))
print("l1=", l1)
print("l2 = ", l2)
return (l1+l2)
我在各种数据集上尝试了损失函数,但效果不好。请提供建议。