在PyTorch中索引多维张量中的最大元素

时间:2019-01-05 23:02:55

标签: python multidimensional-array deep-learning pytorch tensor

我正在尝试在多维张量中沿最后一个维度索引最大元素。例如,说我有张量

A = torch.randn((5, 2, 3))
_, idx = torch.max(A, dim=2)

此处idx存储最大索引,该索引可能类似于

>>>> A
tensor([[[ 1.0503,  0.4448,  1.8663],
     [ 0.8627,  0.0685,  1.4241]],

    [[ 1.2924,  0.2456,  0.1764],
     [ 1.3777,  0.9401,  1.4637]],

    [[ 0.5235,  0.4550,  0.2476],
     [ 0.7823,  0.3004,  0.7792]],

    [[ 1.9384,  0.3291,  0.7914],
     [ 0.5211,  0.1320,  0.6330]],

    [[ 0.3292,  0.9086,  0.0078],
     [ 1.3612,  0.0610,  0.4023]]])
>>>> idx
tensor([[ 2,  2],
    [ 0,  2],
    [ 0,  0],
    [ 0,  2],
    [ 1,  0]])

我希望能够访问这些索引并基于它们分配给另一个张量。表示我希望能够做到

B = torch.new_zeros(A.size())
B[idx] = A[idx]

其中B始终为0,除了A沿最后一个维度最大。那是B应该存储的

>>>>B
tensor([[[ 0,  0,  1.8663],
     [ 0,  0,  1.4241]],

    [[ 1.2924,  0,  0],
     [ 0,  0,  1.4637]],

    [[ 0.5235,  0,  0],
     [ 0.7823,  0,  0]],

    [[ 1.9384,  0,  0],
     [ 0,  0,  0.6330]],

    [[ 0,  0.9086,  0],
     [ 1.3612,  0,  0]]])

事实证明,这比我预期的要困难得多,因为idx无法正确索引数组A。到目前为止,我一直无法找到使用idx索引A的向量化解决方案。

是否有一种很好的矢量化方法?

3 个答案:

答案 0 :(得分:1)

一个丑陋的解决方法是从idx中创建一个二进制掩码,并使用它对数组进行索引。基本代码如下:

import torch
torch.manual_seed(0)

A = torch.randn((5, 2, 3))
_, idx = torch.max(A, dim=2)

mask = torch.arange(A.size(2)).reshape(1, 1, -1) == idx.unsqueeze(2)
B = torch.zeros_like(A)
B[mask] = A[mask]
print(A)
print(B)

诀窍是torch.arange(A.size(2))枚举idx中可能的值,而mask在等于​​idx的地方不为零。备注:

  1. 如果您确实舍弃了torch.max的第一个输出,则可以改用torch.argmax
  2. 我认为这只是一些更广泛问题的最小示例,但是请注意,您目前正在重编(1, 1, 3)内核的torch.nn.functional.max_pool3d
  3. 另外,请注意,使用遮罩分配对张量进行就地修改可能会导致自动缩放问题,因此您可能需要使用torch.where,如here所示。

我希望有人会提出一个更干净的解决方案(避免mask数组的中间分配),可能使用torch.index_select,但我现在无法使其正常工作

答案 1 :(得分:1)

您可以使用torch.meshgrid创建索引元组:

>>> index_tuple = torch.meshgrid([torch.arange(x) for x in A.size()[:-1]]) + (idx,)
>>> B = torch.zeros_like(A)
>>> B[index_tuple] = A[index_tuple]

请注意,您也可以通过以下方式模仿meshgrid :(对于3D的特定情况):

>>> index_tuple = (
...     torch.arange(A.size(0))[:, None],
...     torch.arange(A.size(1))[None, :],
...     idx
... )

更多说明:
  我们将获得如下索引:

In [173]: idx 
Out[173]: 
tensor([[2, 1],
        [2, 0],
        [2, 1],
        [2, 2],
        [2, 2]])

由此,我们想转到三个索引(由于张量是3D的,因此需要三个数字来检索每个元素)。基本上,我们希望在前两个维度中构建一个网格,如下所示。 (这就是为什么我们使用meshgrid的原因。)

In [174]: A[0, 0, 2], A[0, 1, 1]  
Out[174]: (tensor(0.6288), tensor(-0.3070))

In [175]: A[1, 0, 2], A[1, 1, 0]  
Out[175]: (tensor(1.7085), tensor(0.7818))

In [176]: A[2, 0, 2], A[2, 1, 1]  
Out[176]: (tensor(0.4823), tensor(1.1199))

In [177]: A[3, 0, 2], A[3, 1, 2]    
Out[177]: (tensor(1.6903), tensor(1.0800))

In [178]: A[4, 0, 2], A[4, 1, 2]          
Out[178]: (tensor(0.9138), tensor(0.1779))

在上述5行中,索引中的前两个数字基本上是我们使用meshgrid构建的网格,而第三个数字来自idx

即前两个数字组成一个网格。

 (0, 0) (0, 1)
 (1, 0) (1, 1)
 (2, 0) (2, 1)
 (3, 0) (3, 1)
 (4, 0) (4, 1)

答案 2 :(得分:0)

可以使用 torch.scatter enter image description here

y
--------------------
8003372602728