在pytorch中,y.backward([0.1,1.0,0.0001])

时间:2018-03-23 02:22:49

标签: pytorch

在pytorch y.backward([0.1, 1.0, 0.0001])的含义是什么?

我理解y.backward()意味着做反向传播。 但是[0.1, 1.0, 0.0001]y.backward([0.1, 1.0, 0.0001])的含义是什么?

3 个答案:

答案 0 :(得分:2)

表达式apiVersion: 1 deleteDatasources: - name: NameOfDataSource orgId: 1 datasources: - name: "NameOfDataSource" type: "postgres" access: "proxy" url: "172.17.0.4:5432" user: "usernamme" password: "passwordOfUser" database: "database" basicAuth: false isDefault: false jsonData: {sslmode: "disable"} readOnly: false editable: true 实际上是错误的。它应该是y.backward([0.1, 1.0, 0.0001]),其中y.backward(torch.Tensor([0.1, 1.0, 0.0001]))是将计算导数的变量。

示例

torch.Tensor([0.1, 1.0, 0.0001])

此处x = Variable(torch.ones(2, 2), requires_grad=True) y = (x + 2).mean() y.backward(torch.Tensor([1.0])) print(x.grad) 以及y = (x + 2)/4 dy/dx_i = 0.25以来x_i = 1.0。另请注意,y.backward(torch.Tensor([1.0]))y.backward()是等效的。

如果你这样做:

y.backward(torch.Tensor([0.1]))
print(x.grad)

打印:

Variable containing:
1.00000e-02 *
  2.5000  2.5000
  2.5000  2.5000
[torch.FloatTensor of size 2x2]

只是0.1 * 0.25 = 0.025。所以,现在如果你计算:

y.backward(torch.Tensor([0.1, 0.01]))
print(x.grad)

然后打印:

Variable containing:
1.00000e-02 *
  2.5000  0.2500
  2.5000  0.2500
[torch.FloatTensor of size 2x2]

其中,dy/dx_11 = dy/d_x21 = 0.025dy/dx_12 = dy/d_x22 = 0.0025

查看backward()的函数原型。您可以考虑查看此example

答案 1 :(得分:0)

首先,它不是y.backward([0.1, 1.0, 0.0001],因为在pyTorch中,任何参数都应为Tensor。因此,正确的应该是y.backward(torch.tensor([0.1, 1.0, 0.0001], dtype=torch.float))。 使用此处的链接检查文档autograd

第二,此torch.tensor([0.1, 1.0, 0.0001], dtype=torch.float)创建一个包含3个元素的1-d张量。代码y.backward(torch.tensor([0.1, 1.0, 0.0001])实际上是在与y计算向量乘积。

答案 2 :(得分:0)

这是输出为矢量时的自动分级。我们无法获得矢量的梯度,我们需要将此矢量转换为缩放器。梯度参数是将这个矢量转换为缩放器的权重。

例如,输入:x = [x1,x2,x3]和运算:y = 2 * x = [2 * x1,2 * x2,2 * x3]

然后,我们无法获得dy / dx。如果我们有y.backward(torch.tensor([0.1,1,0.001])),则意味着我们还有另一个变量: output = torch.sum(y * [0.1,1,0.001])= 0.2 * x1 + 2 * x2 + 0.002 * x3。

然后,我们可以得到d(out)/ dx,并且d(out)/ dx将存储在x.grad中。在我们的示例中,x.grad = [d(out)/ dx1,d(out)/ dx2,d(out)/ dx3] = [0.2,2,0.002]。