更改np数组不会自动更改Torch张量吗?

时间:2018-09-17 18:50:00

标签: python numpy pytorch torch tensor

我正在研究PyTorch的基本教程,并遇到了NumPy数组和Torch张量之间的转换。该文档说:

  

Torch Tensor和NumPy数组将共享其基础内存位置,而更改一个将更改另一个。

但是,以下代码似乎并非如此:

import numpy as np

a = np.ones((3,3))
b = torch.from_numpy(a)

np.add(a,1,out=a)
print(a)
print(b)

在上述情况下,我看到更改自动反映在输出中:

[[2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]]
tensor([[2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]], dtype=torch.float64)

但是当我写这样的东西时,不会发生同样的事情:

a = np.ones((3,3))
b = torch.from_numpy(a)

a = a + 1
print(a)
print(b)

我得到以下输出:

[[2. 2. 2.]
 [2. 2. 2.]
 [2. 2. 2.]]
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], dtype=torch.float64)

我在这里想念什么?

1 个答案:

答案 0 :(得分:4)

每当您在Python中编写=符号时,您都在创建一个新对象。

因此,在第二种情况下,表达式的右侧使用原始a,然后求值为新对象,即a + 1,它将替换原始a。 b仍然指向原始a的存储位置,但是现在a指向内存中的新对象。

换句话说,在a = a + 1中,表达式a + 1创建了一个新对象,然后Python将该新对象分配为名称a

a += 1中,Python用参数1调用a的就地加法(__iadd__)。

在第一种情况下,numpy代码np.add(a,1,out=a) 负责将该值就地添加到现有数组中。

(感谢@Engineero@Warren Weckesser在评论中指出这些解释)