使用排序索引重新排列3-D数组?

时间:2018-07-20 01:10:45

标签: python-3.x sorting for-loop vectorization pytorch

我有一个随机数为[channels = 3, height = 10, width = 10]的3D数组。

enter image description here 然后,我使用pytorch的sort命令对这些列进行了排序,并获得了索引。

enter image description here

对应的索引如下所示:

enter image description here

现在,我想使用这些索引返回到原始矩阵。我目前使用for循环来执行此操作(不考虑批次)。代码是:

import torch
torch.manual_seed(1)
ch = 3
h = 10
w = 10
inp_unf = torch.randn(ch,h,w)
inp_sort, indices  = torch.sort(inp_unf,1)
resort = torch.zeros(inp_sort.shape)

for i in range(ch):
    for j in range(inp_sort.shape[1]):
        for k in range (inp_sort.shape[2]):
            temp = inp_sort[i,j,k]
            resort[i,indices[i,j,k],k] = temp

我希望也考虑批量处理,即输入大小为[batch, channel, height, width]

1 个答案:

答案 0 :(得分:1)

使用Tensor.scatter_()

您可以使用sort()提供的索引将已排序的张量直接散射回其原始状态:

torch.zeros(ch,h,w).scatter_(dim=1, index=indices, src=inp_sort)

直觉是基于下面的上一个答案。由于scatter()gather()基本上相反,因此inp_reunf = inp_sort.gather(dim=1, index=reverse_indices)inp_reunf.scatter_(dim=1, index=indices, src=inp_sort)相同:


上一个答案

注意 :虽然正确,但由于第二次调用sort()操作,因此性能可能较低。

您需要获取排序“ 反向索引”,这可以通过“ sort() 返回的索引进行排序”来完成。

换句话说,给定x_sort, indices = x.sort(),您就有x[indices] -> x_sort;而您想要的是reverse_indices这样的x_sort[reverse_indices] -> x

这可以通过以下方式获得:_, reverse_indices = indices.sort()

import torch
torch.manual_seed(1)
ch, h, w = 3, 10, 10
inp_unf = torch.randn(ch,h,w)

inp_sort, indices = inp_unf.sort(dim=1)
_, reverse_indices = indices.sort(dim=1)
inp_reunf = inp_sort.gather(dim=1, index=reverse_indices)

print(torch.equal(inp_unf, inp_reunf))
# True
相关问题