PyTorch 张量:基于旧张量和索引的新张量

时间:2021-03-31 09:30:27

标签: python pytorch torch

我是张量的新手,对这个问题很头疼:

我有一个大小为 k 的索引张量,其值在 0 到 k-1 之间:

tensor([0,1,2,0])

和以下矩阵:

tensor([[[0, 9],
     [1, 8],
     [2, 3],
     [4, 9]]])

我想创建一个新的张量,其中包含索引中指定的行,按顺序排列。所以我想要:

tensor([[[0, 9],
     [1, 8],
     [2, 3],
     [0, 9]]])

外部张量我或多或少会像这样执行此操作:

new_matrix = [matrix[i] for i in index]

如何在 PyTorch 中对张量执行类似的操作?

1 个答案:

答案 0 :(得分:1)

您使用 fancy indexing

from torch import tensor

index = tensor([0,1,2,0])
t = tensor([[[0, 9],
     [1, 8],
     [2, 3],
     [0, 9]]])

result = t[:, index, :]

得到

tensor([[[0, 9],
         [1, 8],
         [2, 3],
         [0, 9]]])

注意 t.shape == (1, 4, 2) 并且您希望在 second 轴上建立索引;所以我们将它应用到第二个参数中,并通过 :s 即 [:, index, :] 保持其余部分不变。