我有两个多维张量a
和b
。我希望按a
。
我发现tf.nn.top_k
能够对张量进行排序并返回用于对输入进行排序的索引。如何使用tf.nn.top_k(a, k=2)
中返回的索引对b
进行排序?
例如,
import tensorflow as tf
a = tf.reshape(tf.range(30), (2, 5, 3))
b = tf.reshape(tf.range(210), (2, 5, 3, 7))
k = 2
sorted_a, indices = tf.nn.top_k(a, k)
# How to sort b into
# sorted_b[0, 0, 0, :] = b[0, 0, indices[0, 0, 0], :]
# sorted_b[0, 0, 1, :] = b[0, 0, indices[0, 0, 1], :]
# sorted_b[0, 1, 0, :] = b[0, 1, indices[0, 1, 0], :]
# ...
将tf.gather_nd
与tf.meshgrid
结合起来可以是一种解决方案。例如,以下代码在python 3.5上使用张量流1.0.0-rc0
进行测试:
a = tf.reshape(tf.range(30), (2, 5, 3))
b = tf.reshape(tf.range(210), (2, 5, 3, 7))
k = 2
sorted_a, indices = tf.nn.top_k(a, k)
shape_a = tf.shape(a)
auxiliary_indices = tf.meshgrid(*[tf.range(d) for d in (tf.unstack(shape_a[:(a.get_shape().ndims - 1)]) + [k])], indexing='ij')
sorted_b = tf.gather_nd(b, tf.stack(auxiliary_indices[:-1] + [indices], axis=-1))
但是,我想知道是否有一个更具可读性且无需在上面创建auxiliary_indices
的解决方案。
答案 0 :(得分:0)
您的代码有问题。
b = tf.reshape(tf.range(60), (2, 5, 3, 7))
因为TensorFlow无法重塑具有60个元素的张量以形成[2,5,3,7](210个元素)。 并且你不能使用3级张量的指数对等级4张量(b)进行排序。