我有一个随机数为[channels = 3, height = 10, width = 10]
的3D数组。
然后,我使用pytorch的sort命令对这些列进行了排序,并获得了索引。
对应的索引如下所示:
现在,我想使用这些索引返回到原始矩阵。我目前使用for
循环来执行此操作(不考虑批次)。代码是:
import torch
torch.manual_seed(1)
ch = 3
h = 10
w = 10
inp_unf = torch.randn(ch,h,w)
inp_sort, indices = torch.sort(inp_unf,1)
resort = torch.zeros(inp_sort.shape)
for i in range(ch):
for j in range(inp_sort.shape[1]):
for k in range (inp_sort.shape[2]):
temp = inp_sort[i,j,k]
resort[i,indices[i,j,k],k] = temp
我希望也考虑批量处理,即输入大小为[batch, channel, height, width]
。
答案 0 :(得分:1)
Tensor.scatter_()
您可以使用sort()
提供的索引将已排序的张量直接散射回其原始状态:
torch.zeros(ch,h,w).scatter_(dim=1, index=indices, src=inp_sort)
直觉是基于下面的上一个答案。由于scatter()
与gather()
基本上相反,因此inp_reunf = inp_sort.gather(dim=1, index=reverse_indices)
与inp_reunf.scatter_(dim=1, index=indices, src=inp_sort)
相同:
注意 :虽然正确,但由于第二次调用sort()
操作,因此性能可能较低。 >
您需要获取排序“ 反向索引”,这可以通过“ 对sort()
返回的索引进行排序”来完成。
换句话说,给定x_sort, indices = x.sort()
,您就有x[indices] -> x_sort
;而您想要的是reverse_indices
这样的x_sort[reverse_indices] -> x
。
这可以通过以下方式获得:_, reverse_indices = indices.sort()
。
import torch
torch.manual_seed(1)
ch, h, w = 3, 10, 10
inp_unf = torch.randn(ch,h,w)
inp_sort, indices = inp_unf.sort(dim=1)
_, reverse_indices = indices.sort(dim=1)
inp_reunf = inp_sort.gather(dim=1, index=reverse_indices)
print(torch.equal(inp_unf, inp_reunf))
# True