为什么在不带分离的Pytorch中计算图不知道张量的情况下,为什么能够更改张量的值?

时间:2020-06-16 18:26:40

标签: python pytorch

我可以更改需要grad的张量的值,而无需autograd知道它:

def error_unexpected_way_to_by_pass_safety():
    import torch 
    a = torch.tensor([1,2,3.], requires_grad=True)
    # are detached tensor's leafs? yes they are
    a_detached = a.detach()
    #a.fill_(2) # illegal, warns you that a tensor which requires grads is used in an inplace op (so it won't be recorded in computation graph so it wont take the right derivative of the forward path as this op won't be in it)
    a_detached.fill_(2) # weird that this one is allowed, seems to allow me to bypass the error check from the previous comment...?!
    print(f'a = {a}')
    print(f'a_detached = {a_detached}')
    a.sum().backward()

这不会引发任何错误。虽然,我能够更改a的内容,这是需要grad的张量,而autograd却不知道。这意味着计算图不知道该操作(用2填充)。这似乎是错误的。谁能阐明正在发生的事情?

2 个答案:

答案 0 :(得分:3)

.detach使您可以查看相同的数据,因此修改分离张量的数据会修改原始数据。您可以这样检查:

a.data_ptr() == a_detached.data_ptr() # True

对于为什么,这是.detach的实现方式(与防御性复制相反),这是只有PyTorch作者才知道答案的设计问题。我认为这是为了保存不必要的副本,但是用户则需要意识到,如果他们想就地修改分离的张量,必须自己复制张量。

请注意,如果您确实想更改,也可以更改非分离张量:

a.data.fill_(2)

PyTorch并未试图阻止您“入侵”自动毕业;用户仍然必须知道如何正确使用张量,以便可以正确跟踪渐变。

答案 1 :(得分:1)

在此处添加到现有答案中。 detach不复制数据的原因绝对是为了节省不必要的副本-如果您想拥有完整副本,则可以始终使用a.clone().detach()版的a(或{{1} })。在某些情况下,您只能执行其中一种操作(例如,a.detach().clone()clone),并且所有这些含义都可以。

一个人想不使用detach而使用detach的最重要原因是因为这是在pytorch中实现所谓的“ StopGradient”操作的方法(tf中的stop_gradient)。想象一下您想在NN中两次使用张量clone的情况,使得梯度在一种情况下传播通过而在另一种情况下不传播(并且预计没人会在适当位置修改张量)

关于a不带clone的情况-似乎有点不寻常,但是我已经看到了这样的示例(大多数人希望确保原始张量不会被更新,但是渐变会传播给它。)

通常要避免就地修改张量(优化器步骤除外)。在任何情况下,您都必须格外小心,因为这是一种使计算图无效的简单方法。梯度计算。