如何使用返回的tf.nn.top_k索引对多维张量进行排序?

时间:2017-01-27 15:15:30

标签: tensorflow

我有两个多维张量ab。我希望按a

的值对它们进行排序

我发现tf.nn.top_k能够对张量进行排序并返回用于对输入进行排序的索引。如何使用tf.nn.top_k(a, k=2)中返回的索引对b进行排序?

例如,

import tensorflow as tf

a = tf.reshape(tf.range(30), (2, 5, 3))
b = tf.reshape(tf.range(210), (2, 5, 3, 7))
k = 2
sorted_a, indices = tf.nn.top_k(a, k)

# How to sort b into
# sorted_b[0, 0, 0, :] = b[0, 0, indices[0, 0, 0], :]
# sorted_b[0, 0, 1, :] = b[0, 0, indices[0, 0, 1], :]
# sorted_b[0, 1, 0, :] = b[0, 1, indices[0, 1, 0], :]
# ...

更新

tf.gather_ndtf.meshgrid结合起来可以是一种解决方案。例如,以下代码在python 3.5上使用张量流1.0.0-rc0进行测试:

a = tf.reshape(tf.range(30), (2, 5, 3))
b = tf.reshape(tf.range(210), (2, 5, 3, 7))
k = 2

sorted_a, indices = tf.nn.top_k(a, k)

shape_a = tf.shape(a)
auxiliary_indices = tf.meshgrid(*[tf.range(d) for d in (tf.unstack(shape_a[:(a.get_shape().ndims - 1)]) + [k])], indexing='ij')

sorted_b = tf.gather_nd(b, tf.stack(auxiliary_indices[:-1] + [indices], axis=-1))

但是,我想知道是否有一个更具可读性且无需在上面创建auxiliary_indices的解决方案。

1 个答案:

答案 0 :(得分:0)

您的代码有问题。

b = tf.reshape(tf.range(60), (2, 5, 3, 7))

因为TensorFlow无法重塑具有60个元素的张量以形成[2,5,3,7](210个元素)。 并且你不能使用3级张量的指数对等级4张量(b)进行排序。