为什么蒸馏损失与交叉熵损失相比不会收敛(减少)?

时间:2020-02-27 06:45:43

标签: python deep-learning pytorch

我将交叉熵和蒸馏损失用于持续学习(又称为增量学习)模型。 但是,蒸馏损失不会收敛,而交叉熵损失会收敛。

这是我的蒸馏损失代码和培训部分。

蒸馏损失代码:

def loss_fn_distillation(outputs, soft_labels, temperature, current_step, total_step, total_label):
    current_label = (total_label / total_step) * (current_step + 1)
    previous_label = (total_label / total_step) * current_step

    soft_labels = V(soft_labels.data, requires_grad=False).cuda()
    soft_labels = torch.softmax(soft_labels / temperature, dim=1)

     outputs = F.log_softmax(outputs[:,:-int(current_label-previous_label)]/temperature, dim = 1)

     distill_loss = torch.sum(outputs * soft_labels, dim=1, keepdim=False)
     distill_loss = -torch.mean(distill_loss, dim=0, keepdim=False)


      return V(distill_loss, requires_grad=True).cuda()

培训部分的代码:

    outputs = net(inputs)
    ce_loss = criterion(outputs, targets)

    if(i>0) :

        soft_label = previous_net(inputs)

        distill_loss = loss_fn_distillation(outputs=outputs, soft_labels=soft_label, temperature=2,
                                            current_step=i, total_step=step, total_label=number_label)
        print(ce_loss, distill_loss)
        loss = distill_loss + ce_loss


    else :
        loss = ce_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

历时交叉熵损失和蒸馏损失的结果:

enter image description here

感谢您的任何反馈。谢谢。

0 个答案:

没有答案