我对PyTorch中的向后功能有疑问

时间:2019-07-29 07:13:17

标签: pytorch

我对pytorch的向后功能有疑问,我认为我没有得到正确的输出

import numpy as np
import torch
from torch.autograd import Variable
a = Variable(torch.FloatTensor([[1,2,3],[4,5,6]]), requires_grad=True) 
out = a * a
out.backward(a)
print(a.grad)

输出为

tensor([[ 2.,  8., 18.],
        [32., 50., 72.]])

也许是2*a*a

但是我认为输出应该是

tensor([[ 1.,  2., 6.],
        [8., 10., 12.]])

2*a.引起d(x^2)/dx=2x

1 个答案:

答案 0 :(得分:6)

请仔细阅读backward()上的文档,以更好地理解它。

默认情况下,pytorch期望为网络的 last 输出(损失函数)调用backward()。损失函数总是输出标量,因此,标量损失与所有其他变量/参数的梯度都得到了很好的定义(使用链式规则)。
因此,默认情况下,backwards()在标量张量上调用,并且不包含任何参数。
例如:

a = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.float, requires_grad=True)
for i in range(2):
  for j in range(3):
    out = a[i,j] * a[i,j]
    out.backward()
print(a.grad)

收益

tensor([[ 2.,  4.,  6.],
        [ 8., 10., 12.]])

符合预期:d(a^2)/da = 2a

但是,当您在2×3 backwards张量(不再是标量函数)上调用out时,您期望a.grad是什么?实际上,您需要一个2×3×2×3的输出:d out[i,j] / d a[k,l](!)
Pytorch不支持此非标量函数派生类。
取而代之的是,pytorch假设out只是一个中间张量,并且在“上游”某处存在一个标量损失函数,该函数通过链式规则提供了d loss/ d out[i,j]。此“上游”梯度的大小为2×3,在这种情况下,实际上是您提供的参数backwardout.backward(g),其中g_ij = d loss/ d out_ij
然后根据链式规则d loss / d a[i,j] = (d loss/d out[i,j]) * (d out[i,j] / d a[i,j])
计算梯度 由于您提供了a作为“上游”渐变,因此

a.grad[i,j] = 2 * a[i,j] * a[i,j]

如果您要提供全部的“上游”渐变

out.backward(torch.ones(2,3))
print(a.grad)

收益

tensor([[ 2.,  4.,  6.],
        [ 8., 10., 12.]])

符合预期。

这都是连锁法则。