如何在pytorch中使用索引张量索引中间维度?

时间:2020-10-29 12:09:15

标签: python indexing pytorch tensor

如何索引具有n个维度的张量t和m个index张量,从而保留t的最后一个维度?对于尺寸m之前的所有尺寸,index张量的形状等于张量t。换句话说,我想索引张量的中间维度,同时保留选定索引的以下所有维度。

例如,假设我们有两个张量:

t = torch.randn([3, 5, 2]) * 10
index = torch.tensor([[1, 3],[0,4],[3,2]]).long()

带有t:

tensor([[[ 15.2165,  -7.9702],
         [  0.6646,   5.2844],
         [-22.0657,  -5.9876],
         [ -9.7319,  11.7384],
         [  4.3985,  -6.7058]],

        [[-15.6854, -11.9362],
         [ 11.3054,   3.3068],
         [ -4.7756,  -7.4524],
         [  5.0977, -17.3831],
         [  3.9152, -11.5047]],

        [[ -5.4265, -22.6456],
         [  1.6639,  10.1483],
         [ 13.2129,   3.7850],
         [  3.8543,  -4.3496],
         [ -8.7577, -12.9722]]])

然后我想要的输出将具有(3, 2, 2)的形状,并且是:

tensor([[[  0.6646,   5.2844],
         [ -9.7319,  11.7384]],
        [[-15.6854, -11.9362],
         [  3.9152, -11.5047]],
        [[  3.8543,  -4.3496],
         [ 13.2129,   3.7850]]])

另一个示例是我有一个形状为t的张量(40, 10, 6, 2)和一个形状为(40, 10, 3)的索引张量。这应该查询张量t的维度3,并且预期的输出形状将为(40, 10, 3, 2)

如何在不使用循环的情况下以通用方式实现这一目标?

1 个答案:

答案 0 :(得分:1)

在这种情况下,您可以执行以下操作:

t[torch.arange(t.shape[0]).unsqueeze(1), index, ...]

完整代码:

import torch

t = torch.tensor([[[ 15.2165,  -7.9702],
                   [  0.6646,   5.2844],
                   [-22.0657,  -5.9876],
                   [ -9.7319,  11.7384],
                   [  4.3985,  -6.7058]],
                  [[-15.6854, -11.9362],
                   [ 11.3054,   3.3068],
                   [ -4.7756,  -7.4524],
                   [  5.0977, -17.3831],
                   [  3.9152, -11.5047]],
                  [[ -5.4265, -22.6456],
                   [  1.6639,  10.1483],
                   [ 13.2129,   3.7850],
                   [  3.8543,  -4.3496],
                   [ -8.7577, -12.9722]]])

index = torch.tensor([[1, 3],[0,4],[3,2]]).long()

output = t[torch.arange(t.shape[0]).unsqueeze(1), index, ...]

# tensor([[[  0.6646,   5.2844],
#          [ -9.7319,  11.7384]],
# 
#         [[-15.6854, -11.9362],
#          [  3.9152, -11.5047]],
# 
#         [[  3.8543,  -4.3496],
#          [ 13.2129,   3.7850]]])