PyTorch: Custom loss involving the norm of the end-to-end Jacobian

Date: 2019-12-02 17:45:42

Tags: pytorch loss-function autodiff

Cross-posted from the PyTorch discussion boards.

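I'd like to train a network using a modified loss function that has both a typical classification loss (e.g. nn.CrossEntropyLoss) and a penalty on the Frobenius norm of the end-to-end Jacobian (i.e., if f(x) is the output of the network, \nabla_x f(x)).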
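For a single example with C output classes and a D-dimensional input, the objective can be written as below, where λ is the scalar coefficient and ℓ_CE the classification loss (these symbols are introduced here for illustration). The last line matters for the steps listed below: seeding a backward pass with a vector of ones yields the gradient of the summed outputs (a vector-Jacobian product), whose 2-norm generally differs from the Frobenius norm of the full Jacobian.

```latex
\mathcal{L}(x, y) = \ell_{\mathrm{CE}}\bigl(f(x), y\bigr) + \lambda \,\lVert J_f(x) \rVert_F,
\qquad
J_f(x) = \frac{\partial f(x)}{\partial x} \in \mathbb{R}^{C \times D}

\lVert J_f(x) \rVert_F = \Bigl(\sum_{i=1}^{C} \sum_{j=1}^{D} \Bigl(\frac{\partial f_i(x)}{\partial x_j}\Bigr)^{2}\Bigr)^{1/2}

% backward() seeded with v = 1 computes only a vector-Jacobian product:
\mathbf{1}^{\top} J_f(x) = \nabla_x \Bigl(\sum_{i=1}^{C} f_i(x)\Bigr) \in \mathbb{R}^{1 \times D}
```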

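I've implemented a model that learns successfully with nn.CrossEntropyLoss. However, when I try to add the second loss term (by doing two backward passes), my training loop runs but the model never learns. Furthermore, if I compute the end-to-end Jacobian but don't include it in the loss function, the model also never learns. At a high level, my code does the following: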

  1. Forward pass from input x to get predicted classes yhat
  2. Call yhat.backward(torch.ones(appropriate shape), retain_graph=True)
  3. Jacobian norm = x.grad.data.norm(2)
  4. Set loss equal to classification loss + scalar coefficient * Jacobian norm
  5. Run loss.backward()
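A minimal sketch of the listed steps done with a single backward pass, assuming a toy two-layer model, random data and a hypothetical helper name jacobian_penalty_loss (none of these come from the original code). It replaces steps 2 and 3 with torch.autograd.grad(..., create_graph=True), which returns the vector-Jacobian product as a differentiable tensor, so step 5 also propagates the penalty to the weights. Like the original steps, it penalizes the norm of the gradient of the summed outputs, not the full Frobenius norm of the Jacobian.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def jacobian_penalty_loss(model, x, target, coeff=0.01):
    """Hypothetical helper: cross-entropy + coeff * ||grad_x sum(yhat)||_2."""
    x = x.clone().requires_grad_(True)   # track gradients w.r.t. the input
    yhat = model(x)                      # step 1: forward pass
    # Steps 2-3: vector-Jacobian product with v = ones. create_graph=True keeps
    # the result differentiable so the penalty can be backpropagated later.
    (vjp,) = torch.autograd.grad(
        outputs=yhat,
        inputs=x,
        grad_outputs=torch.ones_like(yhat),
        create_graph=True,
    )
    penalty = vjp.norm(2)
    # Step 4: total loss.
    return F.cross_entropy(yhat, target) + coeff * penalty


# Toy model and data, assumed for illustration.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 3))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(16, 4)
target = torch.randint(0, 3, (16,))

optimizer.zero_grad()
loss = jacobian_penalty_loss(model, x, target)
loss.backward()                          # step 5: a single backward pass
optimizer.step()
```

Because only one backward call touches the parameters, their .grad holds exactly the gradient of the combined loss when optimizer.step() runs.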

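I suspect I'm misunderstanding how backward() works when it is called twice, but I haven't been able to find any good resources that clarify this.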
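One difference worth checking is what backward() does beyond filling x.grad. The sketch below (a toy nn.Linear model and random input, assumed for illustration) compares backward() with torch.autograd.grad on the same output: backward() also accumulates gradients into every parameter's .grad, while autograd.grad only returns the gradient for the tensors passed as inputs. Assuming loss_fn calls end_to_end_jacobian_loss after optimizer.zero_grad(), those extra parameter gradients would be summed with the classification gradients before optimizer.step().

```python
import torch
import torch.nn as nn

# Toy model and input, assumed for illustration.
torch.manual_seed(0)
model = nn.Linear(4, 3)
x = torch.randn(2, 4, requires_grad=True)

# Variant A: backward() with v = ones. Besides filling x.grad, it also
# accumulates d(sum of outputs)/d(parameter) into every parameter's .grad.
model.zero_grad()
model(x).backward(torch.ones(2, 3))
grads_after_backward = [p.grad.clone() for p in model.parameters()]

# Variant B: autograd.grad returns the gradient w.r.t. the requested inputs
# only and leaves the parameters' .grad attributes untouched.
model.zero_grad(set_to_none=True)
(vjp,) = torch.autograd.grad(model(x), x, grad_outputs=torch.ones(2, 3))

print([g.abs().sum().item() > 0 for g in grads_after_backward])  # [True, True]
print([p.grad for p in model.parameters()])                      # [None, None]
```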

Producing a working example would require too much code, so I've tried to extract the relevant parts:

    import torch


    def train_model(model, train_dataloader, optimizer, loss_fn, device=None):

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        model.train()
        train_loss = 0
        correct = 0
        for batch_idx, (batch_input, batch_target) in enumerate(train_dataloader):
            batch_input, batch_target = batch_input.to(device), batch_target.to(device)
            optimizer.zero_grad()
            batch_input.requires_grad_(True)  # needed to get gradients w.r.t. the input
            model_batch_output = model(batch_input)
            loss = loss_fn(model_output=model_batch_output, model_input=batch_input, model=model, target=batch_target)
            train_loss += loss.item()  # sum up batch loss
            loss.backward()
            optimizer.step()


    def end_to_end_jacobian_loss(model_output, model_input):
        # backward() accumulates gradients into .grad of every leaf tensor that
        # requires grad: here both model_input and the model parameters.
        # ones_like keeps the seed on the same device/dtype as the output.
        model_output.backward(
            torch.ones_like(model_output),
            retain_graph=True)
        # .data returns the gradient detached from the autograd graph.
        jacobian = model_input.grad.data
        jacobian_norm = jacobian.norm(2)
        return jacobian_norm
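A quick way to see what the penalty returned by end_to_end_jacobian_loss looks like to autograd, using a toy model and random input assumed for illustration: the gradient written into model_input.grad is computed without create_graph=True, and .data detaches it in any case, so the resulting norm carries no graph. torch.ones_like is used so the seed matches the output's device and dtype.

```python
import torch
import torch.nn as nn

# Toy model and input, assumed for illustration.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 3))
x = torch.randn(5, 4, requires_grad=True)
out = model(x)

# Same pattern as end_to_end_jacobian_loss.
out.backward(torch.ones_like(out), retain_graph=True)
penalty = x.grad.data.norm(2)

# The penalty is a constant from autograd's point of view.
print(penalty.requires_grad, penalty.grad_fn)  # False None
```

Adding such a tensor to the classification loss changes the reported loss value but not the gradients that loss.backward() computes for the parameters.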

Edit 1: I switched my previous implementation from .backward() to autograd.grad, and it apparently works! What's the difference?

    import torch
    from torch import autograd


    def end_to_end_jacobian_loss(model_output, model_input):
        # autograd.grad returns the gradient w.r.t. `inputs` directly instead of
        # accumulating it into .grad attributes.
        jacobian = autograd.grad(
            outputs=model_output['penultimate_layer'],
            inputs=model_input,
            grad_outputs=torch.ones_like(model_output['penultimate_layer']),
            retain_graph=True,
            only_inputs=True)[0]
        jacobian_norm = jacobian.norm(2)
        return jacobian_norm
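The sketch below, with a toy model and random input assumed for illustration (plus a hypothetical helper jacobian_norm, and the plain model output standing in for model_output['penultimate_layer']), compares autograd.grad with and without create_graph=True. autograd.grad does not write into the parameters' .grad, which avoids the side effect of backward(); but per the torch.autograd.grad documentation, without create_graph=True the returned gradient is not itself differentiable, so a penalty built from it would not influence the weight updates. only_inputs is left out because recent PyTorch documentation lists it as deprecated and ignored.

```python
import torch
import torch.nn as nn

# Toy model and input, assumed for illustration.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 3))
x = torch.randn(5, 4, requires_grad=True)


def jacobian_norm(create_graph):
    """Hypothetical helper: 2-norm of grad_x sum(model(x))."""
    out = model(x)
    (vjp,) = torch.autograd.grad(
        outputs=out,
        inputs=x,
        grad_outputs=torch.ones_like(out),
        retain_graph=True,
        create_graph=create_graph,
    )
    return vjp.norm(2)


# Variant from Edit 1 (no create_graph): the norm is detached from the graph.
print(jacobian_norm(create_graph=False).requires_grad)  # False

# With create_graph=True the norm can be backpropagated to the weights.
model.zero_grad()
jacobian_norm(create_graph=True).backward()
print(model[0].weight.grad is not None)  # True
```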
