Is there a mistake in the PyTorch tutorial?

Time: 2018-08-18 13:39:33

Tags: python pytorch

The official PyTorch tutorial (https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#gradients) states that out.backward() and out.backward(torch.tensor(1)) are equivalent. But that does not seem to be the case.

import torch

x = torch.ones(2, 2, requires_grad=True)
y = x + 2
z = y * y * 3
out = z.mean()

# option 1    
out.backward()

# option 2. Replace! do not leave one after the other
# out.backward(torch.tensor(1))

print(x.grad)

Using option 2 (shown commented out) results in an error.

Note: do not leave both backward calls in place; replace option 1 with option 2.

Is the tutorial out of date? What is the purpose of the argument?

Update: if I use out.backward(torch.tensor(1)) as the tutorial says, I get:

E       RuntimeError: invalid gradient at index 0 - expected type torch.FloatTensor but got torch.LongTensor

../../../anaconda3/envs/phd/lib/python3.6/site-packages/torch/autograd/__init__.py:90: RuntimeError

I also tried out.backward(torch.Tensor(1)), but got:

E       RuntimeError: invalid gradient at index 0 - expected shape [] but got [1]

../../../anaconda3/envs/phd/lib/python3.6/site-packages/torch/autograd/__init__.py:90: RuntimeError
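A quick check of the two constructors seems to explain both errors: out is a 0-d float tensor, so the gradient passed to backward() apparently has to match both its dtype and its shape.

import torch

# torch.tensor(1) infers an integer dtype: a 0-d LongTensor
print(torch.tensor(1).dtype, torch.tensor(1).shape)      # torch.int64 torch.Size([])

# torch.Tensor(1) allocates an uninitialized float tensor of shape [1]
print(torch.Tensor(1).dtype, torch.Tensor(1).shape)      # torch.float32 torch.Size([1])

# out = z.mean() is a 0-d float tensor, so a matching gradient would be
print(torch.tensor(1.0).dtype, torch.tensor(1.0).shape)  # torch.float32 torch.Size([])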

1 Answer:

Answer 0 (score: 2):

You need to use dtype=torch.float:

import torch

x = torch.ones(2, 2, requires_grad=True)
y = x + 2
z = y * y * 3
out = z.mean()

# option 1    
out.backward()
print(x.grad)


x = torch.ones(2, 2, requires_grad=True)
y = x + 2
z = y * y * 3
out = z.mean()

# option 2. Replace! do not leave one after the other
out.backward(torch.tensor(1, dtype=torch.float))

print(x.grad)

Output:

tensor([[ 4.5000,  4.5000],
        [ 4.5000,  4.5000]])
tensor([[ 4.5000,  4.5000],
        [ 4.5000,  4.5000]])
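As for the second part of the question (what the argument is for): when backward() is called on a non-scalar tensor, the gradient argument supplies the vector for the vector-Jacobian product; for a scalar output it defaults to 1. A minimal sketch reusing the same example, where passing torch.ones_like(z) (my choice here) is equivalent to backpropagating from z.sum():

import torch

x = torch.ones(2, 2, requires_grad=True)
y = x + 2
z = y * y * 3                    # z is 2x2, not a scalar, so z.backward() alone raises an error

# supply the gradient w.r.t. z explicitly; ones_like(z) backpropagates as if from z.sum(),
# giving dz/dx = 6 * (x + 2) = 18 in every position
z.backward(torch.ones_like(z))
print(x.grad)                    # tensor([[18., 18.], [18., 18.]])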