我想在进行torch.sort
操作和对已排序的张量进行一些其他修改后恢复原始张量顺序,以便不再对张量进行排序。最好用一个示例对此进行解释:
x = torch.tensor([30., 40., 20.])
ordered, indices = torch.sort(x)
# ordered is [20., 30., 40.]
# indices is [2, 0, 1]
ordered = torch.tanh(ordered) # it doesn't matter what operation is
final = original_order(ordered, indices)
# final must be equal to torch.tanh(x)
我以这种方式实现了该功能:
def original_order(ordered, indices):
z = torch.empty_like(ordered)
for i in range(ordered.size(0)):
z[indices[i]] = ordered[i]
return z
有更好的方法吗?特别是可以避免循环并更有效地计算操作吗?
在我的情况下,我有一个张量为torch.Size([B, N])
的张量,并通过一次调用B
对torch.sort
行的每一行分别进行排序。因此,我必须通过另一个循环调用original_order
B
次。
还有其他想法吗?
编辑1-摆脱内循环
我通过用索引简单地用索引对z进行索引来解决了部分问题:
def original_order(ordered, indices):
z = torch.empty_like(ordered)
z[indices] = ordered
return z
现在,我只需要了解如何避免B
维度上的外部循环。
编辑2-摆脱外部循环
def original_order(ordered, indices, batch_size):
# produce a vector to shift indices by lenght of the vector
# times the batch position
add = torch.linspace(0, batch_size-1, batch_size) * indices.size(1)
indices = indices + add.long().view(-1,1)
# reduce tensor to single dimension.
# Now the indices take in consideration the new length
long_ordered = ordered.view(-1)
long_indices = indices.view(-1)
# we are in the previous case with one dimensional vector
z = torch.zeros_like(long_ordered).float()
z[long_indices] = long_ordered
# reshape to get back to the correct dimension
return z.view(batch_size, -1)
答案 0 :(得分:2)
def original_order(ordered, indices):
return ordered.gather(1, indices.argsort(1))
original = torch.tensor([
[20, 22, 24, 21],
[12, 14, 10, 11],
[34, 31, 30, 32]])
sorted, index = original.sort()
unsorted = sorted.gather(1, index.argsort(1))
assert(torch.all(original == unsorted))
为简单起见,假设t = [30, 10, 20]
,省略了张量表示法。
t.sort()
免费为我们提供了排序后的张量s = [10, 20, 30]
和排序索引i = [1, 2, 0]
。 i
实际上是t.argsort()
的输出。
i
告诉我们如何从t
到s
。 “要将t
排序为s
,请从t
中选择元素1,然后选择2,然后选择0,”。 Argsorting i
为我们提供了另一个排序索引j = [2, 0, 1]
,该索引告诉我们如何从i
到自然数[0, 1, 2]
的规范序列,从而有效地反转了排序。另一种看待它的方式是j
告诉我们如何从s
到t
。 “要将s
排序为t
,请从s
中选择元素2,然后是0,然后是1”。 Argsorting一个排序索引给我们它的“逆索引”,反之亦然。
现在我们有了逆索引,我们将其与正确的torch.gather()
转储到dim
中,并且对张量进行未排序。
研究此问题时,我找不到确切的解决方案,所以我认为这是原始答案。