PyTorch张量高级索引

时间:2020-04-08 08:33:08

标签: python numpy pytorch

假设我有一个矩阵和一个向量,如下所示:

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

y = torch.tensor([0, 2, 1])

有没有办法将其切成x[y],所以结果是:

res = [1, 6, 8]

因此,基本上,我采用y的第一个元素,并采用x中与第一行和元素的列相对应的元素。

欢呼

2 个答案:

答案 0 :(得分:5)

您可以将相应的行索引指定为:

import torch
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

y = torch.tensor([0, 2, 1])

x[range(x.shape[0]), y]
tensor([1, 6, 8])

答案 1 :(得分:2)

pytorch中的高级索引工作原理与NumPy's相同,即,索引数组跨轴一起广播。因此,您可以按照FBruzzesi的回答进行操作。

尽管与np.take_along_axis类似,在pytorch中,您也有torch.gather,用于沿特定轴取值:

x.gather(1, y.view(-1,1)).view(-1)
# tensor([1, 6, 8])