假设我有一个矩阵和一个向量,如下所示:
x = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
y = torch.tensor([0, 2, 1])
有没有办法将其切成x[y]
,所以结果是:
res = [1, 6, 8]
因此,基本上,我采用y
的第一个元素,并采用x
中与第一行和元素的列相对应的元素。
欢呼
答案 0 :(得分:5)
您可以将相应的行索引指定为:
import torch
x = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
y = torch.tensor([0, 2, 1])
x[range(x.shape[0]), y]
tensor([1, 6, 8])
答案 1 :(得分:2)
pytorch中的高级索引工作原理与NumPy's
相同,即,索引数组跨轴一起广播。因此,您可以按照FBruzzesi的回答进行操作。
尽管与np.take_along_axis
类似,在pytorch中,您也有torch.gather
,用于沿特定轴取值:
x.gather(1, y.view(-1,1)).view(-1)
# tensor([1, 6, 8])