PyTorch函数中的下划线后缀是什么意思?

时间:2018-10-21 21:33:12

标签: python pytorch

在PyTorch中,张量的许多方法有两种版本-一种带有下划线后缀,一种没有。如果我尝试一下,它们似乎会做同样的事情:

In [1]: import torch

In [2]: a = torch.tensor([2, 4, 6])

In [3]: a.add(10)
Out[3]: tensor([12, 14, 16])

In [4]: a.add_(10)
Out[4]: tensor([12, 14, 16])

两者之间有什么区别

  • torch.addtorch.add_
  • torch.subtorch.sub_
  • ...等等?

2 个答案:

答案 0 :(得分:2)

您已经回答了自己的问题,即下划线表示PyTorch中的就地操作。但是,我想简要指出为什么就地操作会出现问题:

  • 首先,在PyTorch网站上,建议在大多数情况下不要使用就地操作。除非在沉重的内存压力下工作,否则在大多数情况下不使用就地操作会更有效率。
    https://pytorch.org/docs/stable/notes/autograd.html#in-place-operations-with-autograd

  • 第二,使用就地操作时,计算梯度可能会出现问题:

      

    每个张量都保留一个版本计数器,每次都会递增   在任何操作中都被标记为脏污。当函数保存任何张量时   为了向后,将其包含Tensor的版本计数器保存为   好。访问self.saved_tensors后,将对其进行检查,如果已   大于保存的值会引发错误。这样可以确保   您正在使用就地功能并且没有看到任何错误,您可以   确保计算出的梯度正确。    Same source as above.

以下是从您发布的答案中摘录并略作修改的示例:

首先是就地版本:

import torch
a = torch.tensor([2, 4, 6], requires_grad=True, dtype=torch.float)
adding_tensor = torch.rand(3)
b = a.add_(adding_tensor)
c = torch.sum(b)
c.backward()
print(c.grad_fn)

这会导致以下错误:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-27-c38b252ffe5f> in <module>
      2 a = torch.tensor([2, 4, 6], requires_grad=True, dtype=torch.float)
      3 adding_tensor = torch.rand(3)
----> 4 b = a.add_(adding_tensor)
      5 c = torch.sum(b)
      6 c.backward()

RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.

第二个非就地版本:

import torch
a = torch.tensor([2, 4, 6], requires_grad=True, dtype=torch.float)
adding_tensor = torch.rand(3)
b = a.add(adding_tensor)
c = torch.sum(b)
c.backward()
print(c.grad_fn)

哪个可以正常工作-输出:

<SumBackward0 object at 0x7f06b27a1da0>

作为总结,我只是想指出要在PyTorch中谨慎使用就地操作。

答案 1 :(得分:1)

根据documentation,以下划线结尾的方法会在就地上更改张量。这意味着执行该操作不会分配新的内存,该操作通常 increase performance,但can lead to problems and worse performance in PyTorch

In [2]: a = torch.tensor([2, 4, 6])

tensor.add()

In [3]: b = a.add(10)

In [4]: a is b
Out[4]: False # b is a new tensor, new memory was allocated

tensor._add()

In [3]: b = a.add_(10)

In [4]: a is b
Out[4]: True # Same object, no new memory was allocated

请注意,运算符++=也是two different implementations+使用.add()创建新的张量,而+=使用.add_()修改张量

In [2]: a = torch.tensor([2, 4, 6])

In [3]: id(a)
Out[3]: 140250660654104

In [4]: a += 10

In [5]: id(a)
Out[5]: 140250660654104 # Still the same object, no memory allocation was required

In [6]: a = a + 10

In [7]: id(a)
Out[7]: 140250649668272 # New object was created