我有一个3 x 3 PyTorch LongTensor,看起来像这样:
A =
[0, 0, 0]
[1, 2, 2]
[1, 2, 3]
我希望我们能像这样对4 x 2 FloatTensor进行索引:
B =
[0.4, 0.5]
[1.2, 1.4]
[0.8, 1.9]
[2.4, 2.9]
我的预期输出是下面的2 x 3 x 3 FloatTensor:
C[0,:,:] =
[0.4, 0.4, 0.4]
[1.2, 0.8, 0.8]
[1.2, 0.8, 2.4]
C[1,:,:] =
[0.5, 0.5, 0.5]
[1.4, 1.9, 1.9]
[1.4, 1.9, 2.9]
换句话说,矩阵A
是索引和广播矩阵B
。 A
是B
索引的矩阵,因此该操作本质上是一个索引操作。
如何使用torch.index_select()
功能完成此操作?如果解决方案涉及添加或排列维度,那很好。
答案 0 :(得分:0)
使用index_select()
要求索引值在向量而不是张量中。但只要格式正确,该函数就会为您处理广播。必须要做的最后一件事是重塑输出,我相信由于广播。
将成功执行此操作的单行
torch.index_select(B, 0, A.view(-1)).view(3,-1,2).permute(2,0,1)
A.view(-1)
对索引矩阵进行矢量化。
__.view(3,-1,2)
重新形成索引矩阵的形状,但考虑了大小为2的新额外维度(因为我正在索引N x 2矩阵)。
最后,__.permute(2,0,1)
重新整形矩阵,以便输出在单独的通道(而不是每列)中查看B
的每个维度。