多层损失的火炬中的梯度行为

时间:2019-05-10 11:09:21

标签: deep-learning pytorch

我有一个损失,每一层都在造成损失。在确保权重正确更新方面,哪种方法正确?

# option 1
x2 = self.layer1(x1)
x3 = self.layer2(x2)
x4 = self.layer3(x3)

在此选项中,我在喂入每个后续块时会分离

    # option 2
    # x2 = self.layer1(x1.detach())
    # x3 = self.layer2(x2.detach())
    # x4 = self.layer3(x3.detach())

共享操作计算出4个损失并将其相加。

    x4 = F.relu(self.bn1(x4))
    loss = some_loss([x1, x2, x3, x4])

1 个答案:

答案 0 :(得分:0)

选项1是正确的。分离张量时,计算历史/图形将丢失,并且不会将梯度传播到输入/分离之前的计算中。

这也可以通过玩具实验看到。

In [14]: import torch                                                                                                                                                                                 

In [15]: x = torch.rand(10,10).requires_grad_()                                                                                                                                                       

In [16]: y = x**2                                                                                                                                                                                     

In [19]: z = torch.sum(y)                                                                                                                                                                             

In [20]: z.backward()                                                                                                                                                                                 

In [23]: x.grad is not None                                                                                                                                                                           
Out[23]: True

使用分离

In [26]: x = torch.rand(10,10).requires_grad_()                                                                                                                                                       

In [27]: y = x**2                                                                                                                                                                                     

In [28]: z = torch.sum(y)                                                                                                                                                                             

In [29]: z_ = z.detach()                                                                                                                                                                              

In [30]: z_.backward()  
# this gives error

这是因为当您调用detach时,它将返回带有复制值的新张量,并且有关先前计算的信息将会丢失。

相关问题