我有以下张量
inp = tensor([[[ 0.0000e+00, 5.7100e+02, -6.9846e+00],
[ 0.0000e+00, 4.4070e+03, -7.1008e+00],
[ 0.0000e+00, 3.0300e+02, -7.2226e+00],
[ 0.0000e+00, 6.8000e+01, -7.2777e+00],
[ 1.0000e+00, 5.7100e+02, -6.9846e+00],
[ 1.0000e+00, 4.4070e+03, -7.1008e+00],
[ 1.0000e+00, 3.0300e+02, -7.2226e+00],
[ 1.0000e+00, 6.8000e+01, -7.2777e+00]],
[[ 0.0000e+00, 2.1610e+03, -7.0754e+00],
[ 0.0000e+00, 6.8000e+01, -7.2259e+00],
[ 0.0000e+00, 1.0620e+03, -7.2920e+00],
[ 0.0000e+00, 2.9330e+03, -7.3009e+00],
[ 1.0000e+00, 2.1610e+03, -7.0754e+00],
[ 1.0000e+00, 6.8000e+01, -7.2259e+00],
[ 1.0000e+00, 1.0620e+03, -7.2920e+00],
[ 1.0000e+00, 2.9330e+03, -7.3009e+00]],
[[ 0.0000e+00, 4.4070e+03, -7.1947e+00],
[ 0.0000e+00, 3.5600e+02, -7.2958e+00],
[ 0.0000e+00, 3.0300e+02, -7.3232e+00],
[ 0.0000e+00, 1.2910e+03, -7.3615e+00],
[ 1.0000e+00, 4.4070e+03, -7.1947e+00],
[ 1.0000e+00, 3.5600e+02, -7.2958e+00],
[ 1.0000e+00, 3.0300e+02, -7.3232e+00],
[ 1.0000e+00, 1.2910e+03, -7.3615e+00]]])
形状
torch.Size([3, 8, 3])
我想在 dim1 中找到 topk(k=4) 元素,其中要排序的值是 dim2(负值)。生成的张量形状应该是:
torch.Size([3, 4, 3])
我知道如何对单个张量进行 topk,但如何一次对多个批次进行此操作?
答案 0 :(得分:1)
我是这样做的:
val, ind = inp[:, :, 2].squeeze().topk(k=4, dim=1, sorted=True)
new_ind = ind.unsqueeze(-1).repeat(1,1,3)
result = inp.gather(1, new_ind)
我不知道这是否是最好的方法,但它奏效了。
答案 1 :(得分:0)
一种方法是将 fancy indexing 和 broadcasting 组合如下:
我以形状为 x
和 (3, 4, 3)
的随机张量 k
为 2 作为示例。
>>> import torch
>>> x = torch.rand(3, 4, 3)
>>> x
tensor([[[0.0256, 0.7366, 0.2528],
[0.5596, 0.9450, 0.5795],
[0.8265, 0.5469, 0.8304],
[0.4223, 0.5206, 0.2898]],
[[0.2159, 0.0369, 0.6869],
[0.4556, 0.5804, 0.3169],
[0.8194, 0.5240, 0.0055],
[0.8357, 0.4162, 0.3740]],
[[0.3849, 0.0223, 0.9951],
[0.2872, 0.5952, 0.6570],
[0.1433, 0.8450, 0.6557],
[0.0270, 0.9176, 0.3904]]])
现在沿着所需的维度(这里是最后一个)对张量进行排序并获得索引:
>>> _, idx = torch.sort(x[:, :, -1])
>>> k = 2
>>> idx = idx[:, :k]
# idx is =
tensor([[0, 3],
[2, 1],
[3, 2]])
现在生成三对索引 (i, j, k)
来对原始张量进行切片,如下所示:
>>> i = torch.arange(x.shape[0]).reshape(x.shape[0], 1, 1)
>>> j = idx.reshape(x.shape[0], -1, 1)
>>> k = torch.arange(x.shape[2]).reshape(1, 1, x.shape[2])
请注意,一旦您通过 (i, j, k)
索引任何内容,它们将转到 expand 并采用形状 (x.shape[0], k, x.shape[2])
,这是此处所需的输出形状。
现在只需通过 i、j 和 k 索引 x
:
>>> x[i, j, k]
tensor([[[0.0256, 0.7366, 0.2528],
[0.4223, 0.5206, 0.2898]],
[[0.8194, 0.5240, 0.0055],
[0.4556, 0.5804, 0.3169]],
[[0.0270, 0.9176, 0.3904],
[0.1433, 0.8450, 0.6557]]])
本质上,我遵循的一般方法是通过索引数组创建张量的相应访问模式,然后使用这些数组作为索引直接对张量进行切片。
我实际上是为了升序排序这样做的,所以在这里我得到了前 k 个最少的元素。一个简单的解决方法是使用 torch.sort(x[:, :, -1], descending = True)
。