我想在特定维度上对张量进行排序,并返回指定每个元素的排序索引的相同维度的张量。似乎tf.nn.top_k可以返回已排序的索引但是如何将其映射回来?
input = [[10, 3, 1], [5, 6, 2], [1, 7, 10]]
_, indices = tf.nn.top_k(input, k=3, sorted=True)
indices = [[0, 1, 2], [1, 0, 2], [2, 0, 1]]
我希望得到的是
reordered = [[0, 1, 2], [1, 0, 2], [2, 1, 0]]
答案 0 :(得分:2)
假设indices
包含所有索引,即range(k)
的排列,您可以使用
tf.map_fn(tf.invert_permutation, indices)
(tf.invert_permutation
适用于一维张量,因此您需要将其包装到map_fn
以将其应用于indices
的每一行。