我一直在研究不平衡的二进制分类问题,其中真假比率为9:1,输入的是20个暗表数据。为了处理这个不平衡的数据集,我决定使用Focal loss。我对这一指标的实现是对Tensorflow一个的PyTorch改编。
我进行了几次测试,实现焦距损失的结果与TensorFlow版本产生的结果相同。
class FocalLoss(nn.Module):
"""
Weighs the contribution of each sample to the loss based in the classification error.
:gamma: Focusing parameter. gamma=0 is equivalent to BCE_loss
"""
def __init__(self, gamma, eps=1e-6):
super(FocalLoss, self).__init__()
self.gamma = gamma
def forward(self, y_pred, y_true):
y_true = y_true.float()
pred_prob = torch.sigmoid(y_pred)
ce = nn.BCELoss(reduce=False)(pred_prob,y_true)
p_t = (y_true*pred_prob)+((1-y_true)*(1-pred_prob))
modulator = 1.0
if self.gamma:
modulator = torch.pow((1.0-p_t),torch.tensor(self.gamma).to(device) )
return torch.mean(modulator*ce)
我的模型是n隐藏层,在此测试中,n = 4,具有Relu激活的完全连接的NN。当使用交叉熵作为损失函数时,此体系结构工作得相当好。 下图显示了一个历时后的梯度流。验证损失的值接近7e-12,验证精度为50%。
我正在使用adam优化器,其中lr = 1e-4。 您如何看待我对Focal loss的看法?合法吗?