根据索引张量从张量中提取子张量

时间:2021-04-16 10:55:33

标签: python pytorch tensor

我有这个张量:

tensor([[[1, 2],
         [3, 4]],
        [[5, 6],
         [7, 8]]])

我有这个索引张量:

tensor([0, 1])

而我想要得到的是根据dim 1和索引张量中相应索引的子张量,即:

tensor([[1, 2],
        [7, 8]])

尝试使用 torch.gather() 函数和高级索引没有成功,有人可以帮忙吗?

1 个答案:

答案 0 :(得分:1)

您隐式地使用了索引张量的每个值的索引。它们恰好与值相同。如果你想遍历第一级,张量的元素,你可以使用 torch.arange 来构造第一级索引。

import torch
from torch import tensor

t = tensor([[[1, 2],
             [3, 4]],
            [[5, 6],
             [7, 8]]])

ix = tensor([0, 1])
ix0 = torch.arange(0, ix.shape.numel())

t[ix0, ix]
# returns:
tensor([[1, 2],
        [7, 8]])