pytorch:根据索引张量从3d张量中选择列

时间:2019-09-07 11:03:36

标签: pytorch tensor

我有一个尺寸为M的3D张量[BxLxD]和一个尺寸为idx的1D张量[B,1],其中包含范围为(0, L-1)的列索引。我想创建一个尺寸为N的2D张量[BxD],使得N[i,j] = M[i, idx[i], j]。如何有效地做到这一点?

示例:

B,L,D = 2,4,2

M = torch.rand(B,L,D)

>

tensor([[[0.0612, 0.7385],
         [0.7675, 0.3444],
         [0.9129, 0.7601],
         [0.0567, 0.5602]],

        [[0.5450, 0.3749],
         [0.4212, 0.9243],
         [0.1965, 0.9654],
         [0.7230, 0.6295]]])


idx = torch.randint(0, L, size = (B,))

>

tensor([3, 0])

N = get_N(M, idx)

Expected output:

>

tensor([[0.0567, 0.5602], 
       [0.5450, 0.3749]])

谢谢。

1 个答案:

答案 0 :(得分:1)

import torch

B,L,D = 2,4,2

def get_N(M, idx):
    return M[torch.arange(B), idx, :].squeeze()

M = torch.tensor([[[0.0612, 0.7385],
                   [0.7675, 0.3444],
                   [0.9129, 0.7601],
                   [0.0567, 0.5602]],

                   [[0.5450, 0.3749],
                   [0.4212, 0.9243],
                   [0.1965, 0.9654],
                   [0.7230, 0.6295]]])
idx = torch.tensor([3,0])
N = get_N(M, idx)
print(N)

结果:

tensor([[0.0567, 0.5602],
        [0.5450, 0.3749]])

沿二维切片。