a = tensor([[ 1, 2, 3, 4, 5],
[ 6, 7, 8, 8, 10],
[11, 12, 13, 14, 15]])
我有一个torch
张量,我需要索引一个张量c
使得c = [[3], [8], [13]]
因此,我做了c = a[:,[2]]
,它给了我预期的答案,但是在自动分级机上仍然失败。
自动分级机使用如下检查功能-
def check(orig, actual, expected):
expected = torch.tensor(expected)
same_elements = (actual == expected).all().item() == 1
same_storage = (orig.storage().data_ptr() == actual.storage().data_ptr())
return same_elements and same_storage
print('c correct:', check(a, c, [[3], [8], [13]]))
我尝试对其进行调试,结果发现same_storage
是错误的,我不知道为什么orig.storage().data_ptr() == actual.storage().data_ptr()
应该是True
,以及它如何产生作用。
更新
通过执行c = a[:, 2:3]
而不是c = a[:, [2]]
可以得到正确的答案。
答案 0 :(得分:1)
PyTorch允许张量成为现有张量的“视图”,以便它与基本张量共享相同的基础数据,从而避免显式数据复制以执行快速且内存高效的操作。
如Tensor View docs中所述,
通过索引访问张量的内容时,PyTorch遵循Numpy行为,即基本索引返回视图,而高级索引返回副本。
在您的示例中,c = a[:, 2:3]
是基本索引,而c = a[:, [2]]
是高级索引。这就是仅在第一种情况下创建视图的原因。因此,.storage().data_ptr()
得到相同的结果。
您可以在Numpy indexing docs中阅读有关基本索引和高级索引的信息。
高级索引被触发时选择对象obj是非元组序列对象,(数据类型整数或布尔的)一个ndarray,或者与序列对象或ndarray(数据类型中的至少一个整数或一个元组布尔)。