我正在尝试在多维张量中沿最后一个维度索引最大元素。例如,说我有张量
A = torch.randn((5, 2, 3))
_, idx = torch.max(A, dim=2)
此处idx存储最大索引,该索引可能类似于
>>>> A
tensor([[[ 1.0503, 0.4448, 1.8663],
[ 0.8627, 0.0685, 1.4241]],
[[ 1.2924, 0.2456, 0.1764],
[ 1.3777, 0.9401, 1.4637]],
[[ 0.5235, 0.4550, 0.2476],
[ 0.7823, 0.3004, 0.7792]],
[[ 1.9384, 0.3291, 0.7914],
[ 0.5211, 0.1320, 0.6330]],
[[ 0.3292, 0.9086, 0.0078],
[ 1.3612, 0.0610, 0.4023]]])
>>>> idx
tensor([[ 2, 2],
[ 0, 2],
[ 0, 0],
[ 0, 2],
[ 1, 0]])
我希望能够访问这些索引并基于它们分配给另一个张量。表示我希望能够做到
B = torch.new_zeros(A.size())
B[idx] = A[idx]
其中B始终为0,除了A沿最后一个维度最大。那是B应该存储的
>>>>B
tensor([[[ 0, 0, 1.8663],
[ 0, 0, 1.4241]],
[[ 1.2924, 0, 0],
[ 0, 0, 1.4637]],
[[ 0.5235, 0, 0],
[ 0.7823, 0, 0]],
[[ 1.9384, 0, 0],
[ 0, 0, 0.6330]],
[[ 0, 0.9086, 0],
[ 1.3612, 0, 0]]])
事实证明,这比我预期的要困难得多,因为idx无法正确索引数组A。到目前为止,我一直无法找到使用idx索引A的向量化解决方案。
是否有一种很好的矢量化方法?
答案 0 :(得分:1)
一个丑陋的解决方法是从idx
中创建一个二进制掩码,并使用它对数组进行索引。基本代码如下:
import torch
torch.manual_seed(0)
A = torch.randn((5, 2, 3))
_, idx = torch.max(A, dim=2)
mask = torch.arange(A.size(2)).reshape(1, 1, -1) == idx.unsqueeze(2)
B = torch.zeros_like(A)
B[mask] = A[mask]
print(A)
print(B)
诀窍是torch.arange(A.size(2))
枚举idx
中可能的值,而mask
在等于idx
的地方不为零。备注:
torch.max
的第一个输出,则可以改用torch.argmax
。(1, 1, 3)
内核的torch.nn.functional.max_pool3d
。torch.where
,如here所示。我希望有人会提出一个更干净的解决方案(避免mask
数组的中间分配),可能使用torch.index_select
,但我现在无法使其正常工作
答案 1 :(得分:1)
您可以使用torch.meshgrid
创建索引元组:
>>> index_tuple = torch.meshgrid([torch.arange(x) for x in A.size()[:-1]]) + (idx,)
>>> B = torch.zeros_like(A)
>>> B[index_tuple] = A[index_tuple]
请注意,您也可以通过以下方式模仿meshgrid
:(对于3D的特定情况):
>>> index_tuple = (
... torch.arange(A.size(0))[:, None],
... torch.arange(A.size(1))[None, :],
... idx
... )
更多说明:
我们将获得如下索引:
In [173]: idx
Out[173]:
tensor([[2, 1],
[2, 0],
[2, 1],
[2, 2],
[2, 2]])
由此,我们想转到三个索引(由于张量是3D的,因此需要三个数字来检索每个元素)。基本上,我们希望在前两个维度中构建一个网格,如下所示。 (这就是为什么我们使用meshgrid的原因。)
In [174]: A[0, 0, 2], A[0, 1, 1]
Out[174]: (tensor(0.6288), tensor(-0.3070))
In [175]: A[1, 0, 2], A[1, 1, 0]
Out[175]: (tensor(1.7085), tensor(0.7818))
In [176]: A[2, 0, 2], A[2, 1, 1]
Out[176]: (tensor(0.4823), tensor(1.1199))
In [177]: A[3, 0, 2], A[3, 1, 2]
Out[177]: (tensor(1.6903), tensor(1.0800))
In [178]: A[4, 0, 2], A[4, 1, 2]
Out[178]: (tensor(0.9138), tensor(0.1779))
在上述5行中,索引中的前两个数字基本上是我们使用meshgrid构建的网格,而第三个数字来自idx
。
即前两个数字组成一个网格。
(0, 0) (0, 1)
(1, 0) (1, 1)
(2, 0) (2, 1)
(3, 0) (3, 1)
(4, 0) (4, 1)
答案 2 :(得分:0)