使用index_select将一个PyTorch张量索引另一个

时间:2017-10-22 15:56:34

标签: matrix indexing pytorch

我有一个3 x 3 PyTorch LongTensor,看起来像这样:

A = 
    [0, 0, 0]
    [1, 2, 2]
    [1, 2, 3]

我希望我们能像这样对4 x 2 FloatTensor进行索引:

B = 
    [0.4, 0.5]
    [1.2, 1.4]
    [0.8, 1.9]
    [2.4, 2.9]

我的预期输出是下面的2 x 3 x 3 FloatTensor:

C[0,:,:] = 
    [0.4, 0.4, 0.4]
    [1.2, 0.8, 0.8]
    [1.2, 0.8, 2.4]

C[1,:,:] =
    [0.5, 0.5, 0.5]
    [1.4, 1.9, 1.9]
    [1.4, 1.9, 2.9]

换句话说,矩阵A是索引和广播矩阵BAB索引的矩阵,因此该操作本质上是一个索引操作。

如何使用torch.index_select()功能完成此操作?如果解决方案涉及添加或排列维度,那很好。

1 个答案:

答案 0 :(得分:0)

使用index_select()要求索引值在向量而不是张量中。但只要格式正确,该函数就会为您处理广播。必须要做的最后一件事是重塑输出,我相信由于广播。

将成功执行此操作的单行

torch.index_select(B, 0, A.view(-1)).view(3,-1,2).permute(2,0,1)

A.view(-1)对索引矩阵进行矢量化。

__.view(3,-1,2)重新形成索引矩阵的形状,但考虑了大小为2的新额外维度(因为我正在索引N x 2矩阵)。

最后,__.permute(2,0,1)重新整形矩阵,以便输出在单独的通道(而不是每列)中查看B的每个维度。