切片pytorch张量和使用data_ptr()

时间:2020-06-27 08:52:49

标签: arrays numpy pytorch torch

a = tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  8, 10],
        [11, 12, 13, 14, 15]])

我有一个torch张量,我需要索引一个张量c使得c = [[3], [8], [13]]

因此,我做了c = a[:,[2]],它给了我预期的答案,但是在自动分级机上仍然失败。 自动分级机使用如下检查功能-

def check(orig, actual, expected):
  expected = torch.tensor(expected)
  same_elements = (actual == expected).all().item() == 1
  same_storage = (orig.storage().data_ptr() == actual.storage().data_ptr())
  return same_elements and same_storage

print('c correct:', check(a, c, [[3], [8], [13]]))

我尝试对其进行调试,结果发现same_storage是错误的,我不知道为什么orig.storage().data_ptr() == actual.storage().data_ptr()应该是True,以及它如何产生作用。

更新 通过执行c = a[:, 2:3]而不是c = a[:, [2]]可以得到正确的答案。

1 个答案:

答案 0 :(得分:1)

PyTorch允许张量成为现有张量的“视图”,以便它与基本张量共享相同的基础数据,从而避免显式数据复制以执行快速且内存高效的操作。

Tensor View docs中所述,

通过索引访问张量的内容时,PyTorch遵循Numpy行为,即基本索引返回视图,而高级索引返回副本。

在您的示例中,c = a[:, 2:3]是基本索引,而c = a[:, [2]]是高级索引。这就是仅在第一种情况下创建视图的原因。因此,.storage().data_ptr()得到相同的结果。

您可以在Numpy indexing docs中阅读有关基本索引和高级索引的信息。

高级索引被触发时选择对象obj是非元组序列对象,(数据类型整数或布尔的)一个ndarray,或者与序列对象或ndarray(数据类型中的至少一个整数或一个元组布尔)。