def focalloss(input, target, alpha=0.5, gamma=2):
    """
    Compute focal loss for multi-class classification.

    :param input: raw logits (NOT softmax-ed), shape [batch, num_classes]
    :param target: class indices, shape [batch], dtype long
    :param alpha: balancing factor; weight ``alpha`` is applied where
        ``target == 1`` and ``1 - alpha`` elsewhere (the default 0.5 makes
        the weighting uniform, matching the original behavior)
    :param gamma: focusing exponent; larger values down-weight easy
        (high-confidence) examples
    :return: scalar focal loss averaged over the batch
    """
    # Validate shape up front, before it is relied on below.
    assert target.dim() == 1, "target must be a 1-D tensor of class indices"

    # Build the alpha weighting on the same device/dtype as the logits
    # (the original hard-coded .cuda(), which breaks CPU runs).
    alpha_factor = torch.full_like(target, alpha, dtype=input.dtype)
    alpha_factor = torch.where(torch.eq(target, 1), alpha_factor, 1. - alpha_factor)

    # p_t: probability the model currently assigns to the true class.
    input_softmax = F.softmax(input, dim=1)
    pred_weights = torch.gather(input_softmax, dim=1,
                                index=target.unsqueeze(dim=1)).squeeze(1)

    # Modulating factor alpha * (1 - p_t)^gamma; detached so the weighting
    # itself carries no gradient — only the cross-entropy term does.
    focalweights = (alpha_factor * torch.pow(1. - pred_weights, gamma)).detach()

    cel = F.cross_entropy(input, target, reduction='none')
    return (focalweights * cel).sum() / target.numel()
# NOTE(review): question from the original author (translated from Chinese):
# "When I train a model with this loss, it decreases for a few epochs but
# then starts rising... I don't know why this happens — hoping someone can help."