PyTorch-在torch.sort之后恢复原始张量顺序的更好方法

时间:2018-09-01 11:37:04

标签: python pytorch

我想在进行torch.sort操作和对已排序的张量进行一些其他修改后恢复原始张量顺序,以便不再对张量进行排序。最好用一个示例对此进行解释:

x = torch.tensor([30., 40., 20.])
ordered, indices = torch.sort(x)
# ordered is [20., 30., 40.]
# indices is [2, 0, 1]
ordered = torch.tanh(ordered) # it doesn't matter what operation is
final = original_order(ordered, indices) 
# final must be equal to torch.tanh(x)

我以这种方式实现了该功能:

def original_order(ordered, indices):
    z = torch.empty_like(ordered)
    for i in range(ordered.size(0)):
        z[indices[i]] = ordered[i]
    return z

有更好的方法吗?特别是可以避免循环并更有效地计算操作吗?

在我的情况下,我有一个张量为torch.Size([B, N])的张量,并通过一次调用Btorch.sort行的每一行分别进行排序。因此,我必须通过另一个循环调用original_order B次。

还有其他想法吗?

编辑1-摆脱内循环

我通过用索引简单地用索引对z进行索引来解决了部分问题:

def original_order(ordered, indices):
    z = torch.empty_like(ordered)
    z[indices] = ordered
    return z

现在,我只需要了解如何避免B维度上的外部循环。

编辑2-摆脱外部循环

def original_order(ordered, indices, batch_size):
    # produce a vector to shift indices by lenght of the vector 
    # times the batch position
    add = torch.linspace(0, batch_size-1, batch_size) * indices.size(1)


    indices = indices + add.long().view(-1,1)

    # reduce tensor to single dimension. 
    # Now the indices take in consideration the new length
    long_ordered = ordered.view(-1)
    long_indices = indices.view(-1)

    # we are in the previous case with one dimensional vector
    z = torch.zeros_like(long_ordered).float()
    z[long_indices] = long_ordered

    # reshape to get back to the correct dimension
    return z.view(batch_size, -1)

1 个答案:

答案 0 :(得分:2)

def original_order(ordered, indices):
    return ordered.gather(1, indices.argsort(1))

示例

original = torch.tensor([
    [20, 22, 24, 21],
    [12, 14, 10, 11],
    [34, 31, 30, 32]])
sorted, index = original.sort()
unsorted = sorted.gather(1, index.argsort(1))
assert(torch.all(original == unsorted))

为什么起作用

为简单起见,假设t = [30, 10, 20],省略了张量表示法。

t.sort()免费为我们提供了排序后的张量s = [10, 20, 30]和排序索引i = [1, 2, 0]i实际上是t.argsort()的输出。

i告诉我们如何从ts。 “要将t排序为s,请从t中选择元素1,然后选择2,然后选择0,”。 Argsorting i为我们提供了另一个排序索引j = [2, 0, 1],该索引告诉我们如何从i到自然数[0, 1, 2]的规范序列,从而有效地反转了排序。另一种看待它的方式是j告诉我们如何从st。 “要将s排序为t,请从s中选择元素2,然后是0,然后是1”。 Argsorting一个排序索引给我们它的“逆索引”,反之亦然。

现在我们有了逆索引,我们将其与正确的torch.gather()转储到dim中,并且对张量进行未排序。

来源

torch.gather torch.argsort

研究此问题时,我找不到确切的解决方案,所以我认为这是原始答案。