如何在pytorch中复制`grad_fn`?

时间:2019-05-13 06:43:43

标签: pytorch

>>> print(foo.grad_fn)
<AddBackward0 object at 0x7f7f9f450710>

我想从foo.grad_fn复制到bar.grad_fn。作为参考,不需要foo.data。我只想复制gradient

这可能吗?我尝试了以下操作,但失败了。

>>> bar.grad_fn = foo.grad_fn
AttributeError: attribute 'grad_fn' of 'torch._C._TensorBase' objects is not writable

谢谢。

1 个答案:

答案 0 :(得分:0)

实际上,这很容易。您只需执行N + 1就可以访问存储在叶子张量中的渐变。因此,如果要将渐变从一片叶子复制到另一片叶子,只需在调用N之后执行foo.grad.data。请注意,bar.grad.data.copy_(foo.grad.data)用于避免在计算图中跟踪此操作。

如果不是叶子,则必须在backward中指定可选参数data

但是,我不确定您要达到的目的。我不明白为什么将grad从一个张量复制到另一个张量可能有用。如果您告诉我们更多有关您实际要实现的目标的信息,也许可以给您一个更有用的答案。