我的焦点损失实施过程中是否存在一些错误?

时间:2019-12-13 01:38:13

标签: python pytorch loss-function

def focalloss(input, target):
  """
  compute focal loss for multi-classes classification
  :param input: logits not processed by softmax in shape [batch, channel(classes)]
  :param target: in shape [batch], long data_type
  :return: focal loss
  """

  alpha = 0.5
  alpha_factor = torch.ones(target.shape).cuda() * alpha
  alpha_factor = torch.where(torch.eq(target, 1.), alpha_factor, 1. - alpha_factor)

  gamma = 2
  input_softmax = F.softmax(input, dim=1)
  index = target.unsqueeze(dim=1)
  pred_weights = torch.gather(input_softmax, dim=1, index=index).squeeze()
  focalweights = torch.ones_like(pred_weights) - pred_weights

  focalweights = alpha_factor * torch.pow(focalweights, gamma)
  focalweights = focalweights.detach()


  cel = F.cross_entropy(input, target, reduction='none')
  assert len(target.shape) == 1

  fl  = (focalweights*cel).sum()/target.numel()

return fl
  

当我使用此工具训练模型时。前者的损失   几个时代,但以后会上升...我不知道为什么会这样,   希望有人可以帮助我。

0 个答案:

没有答案