用另一个多维张量索引多维火炬张量

时间:2020-07-06 19:00:48

标签: python numpy pytorch tensor numpy-slicing

我在火炬中有一个张量 x ,假设形状为(5,3,2,6),而另一个张量 idx 形状为(5,3,2, 1)包含第一个张量中每个元素的索引。我想用第二张量的索引切片第一个张量。我尝试过x = x [idx],但是当我真的希望它的形状为(5,3,2)或(5,3,2,1)时,我得到了一个奇怪的尺寸。

我将尝试举一个简单的例子: 假设

x=torch.Tensor([[10,20,30],
                 [8,4,43]])
idx = torch.Tensor([[0],
                    [2]])

我想要类似的东西

y = x[idx]

使得'y'输出[[10],[43]]或类似的东西。

索引表示所需元素最后一个维度的位置。对于上面的示例,其中x.shape =(2,3)的最后一个维度是列,则“ idx”中的索引是列。我想要这个,但要超过2个维度

2 个答案:

答案 0 :(得分:0)

根据我的评论,您需要idx作为最后一个维度的索引,idx中的每个索引都对应于x中的类似索引(最后一个维度除外) )。在这种情况下(这是Numpy版本,您可以将其转换为手电筒):

ind = np.indices(idx.shape)
ind[-1] = idx
x[tuple(ind)]

输出:

[[10]
 [43]]

答案 1 :(得分:0)

您可以使用range;和squeeze以获得正确的idx尺寸,例如

x[range(x.size(0)), idx.squeeze()]
tensor([10., 43.])

# or
x[range(x.size(0)), idx.squeeze()].unsqueeze(1)
tensor([[10.],
        [43.]])