pytorch聚焦损耗梯度消失

时间:2020-05-11 19:10:23

标签: python pytorch loss-function

我一直在研究不平衡的二进制分类问题,其中真假比率为9:1,输入的是20个暗表数据。为了处理这个不平衡的数据集,我决定使用Focal loss。我对这一指标的实现是对Tensorflow一个的PyTorch改编。

我进行了几次测试,实现焦距损失的结果与TensorFlow版本产生的结果相同。

class FocalLoss(nn.Module):
    """
    Weighs the contribution of each sample to the loss based in the classification error.
    :gamma: Focusing parameter. gamma=0 is equivalent to BCE_loss
    """
    def __init__(self, gamma, eps=1e-6):
        super(FocalLoss, self).__init__()
        self.gamma = gamma

    def forward(self, y_pred, y_true):
        y_true = y_true.float()
        pred_prob = torch.sigmoid(y_pred)
        ce = nn.BCELoss(reduce=False)(pred_prob,y_true)


        p_t = (y_true*pred_prob)+((1-y_true)*(1-pred_prob))
        modulator = 1.0
        if self.gamma:
            modulator = torch.pow((1.0-p_t),torch.tensor(self.gamma).to(device) )

        return torch.mean(modulator*ce)

我的模型是n隐藏层,在此测试中,n = 4,具有Relu激活的完全连接的NN。当使用交叉熵作为损失函数时,此体系结构工作得相当好。 下图显示了一个历时后的梯度流。验证损失的值接近7e-12,验证精度为50%。

enter image description here

我正在使用adam优化器,其中lr = 1e-4。 您如何看待我对Focal loss的看法?合法吗?

0 个答案:

没有答案
相关问题