如何计算相对于输入的损失梯度?

时间:2021-07-11 17:06:06

标签: python pytorch

我有一个预训练的 PyTorch 模型。我需要使用此模型计算损失相对于网络输入的梯度(无需再次训练,仅使用预训练模型)。

我写了下面的代码,但我不确定它是否正确。

    test_X, test_y = load_data(mode='test')
    testset_original = MyDataset(test_X, test_y, transform=default_transform)
    testloader = DataLoader(testset_original, batch_size=32, shuffle=True)

    model = MyModel(device=device).to(device)
    checkpoint = torch.load('checkpoint.pt')
    model.load_state_dict(checkpoint['model_state_dict'])

    gradient_losses = []
    for i, data in enumerate(testloader):
        inputs, labels = data
        inputs= inputs.to(device)
        labels = labels.to(device)
        inputs.requires_grad = True
        output = model(inputs)
        loss = loss_function(output)
        loss.backward()
        gradient_losses.append(inputs.grad)

我的问题是,这个列表 gradient_losses 是否真的存储了我想要存储的内容?如果不是,那么正确的方法是什么?

1 个答案:

答案 0 :(得分:1)

<块引用>

这个列表 gradient_losses 真的存储了我想要存储的东西吗?

是的,如果您想获得损失相对于输入的导数,那么这似乎是正确的方法。这是最小的示例,以 f(x) = a*x 为例。然后df/dx = a

>>> x = torch.rand(10, requires_grad=True)
>>> y = torch.rand(10)
>>> a = torch.tensor([3.], requires_grad=True)

>>> loss = a*x - y
>>> loss.mean().backward()

>>> x.grad
tensor([0.3000, 0.3000, ..., 0.3000, 0.3000])

其中,在这种情况下等于 a / len(x)

请注意,您使用 input.grad 提取的每个梯度都将在整个批次中取平均值,而不是每个单独输入的梯度。


此外,您不需要 .clone() 输入梯度,因为它们不是模型的一部分,并且不会被 model.zero_grad() 归零。