我想针对子数组对张量进行排序。 例如,我有以下张量:
A = tf.constant([[4, 2, 1, 7, 5],
[10, 20, 30, 40, 50]])
我想对A [0,:]的张量A进行排序。
我期望的结果是:
A = tf.constant([[1, 2, 4, 5, 7],
[30, 20, 10, 50, 40]])
我在堆栈溢出中看到了类似的问题。 (Python, sort array with respect to a sub-array)
但是这个问题与python数组有关,答案不适用于我的问题。
有人可以帮助我吗?谢谢。
答案 0 :(得分:1)
将tf.gather
与tf.argsort
一起使用:
import tensorflow as tf:
a = tf.constant([[4, 2, 1, 7, 5],
[10, 20, 30, 40, 50]])
b = tf.gather(a, tf.argsort(a[0]), axis=1)
b
输出:
<tf.Tensor: id=152, shape=(2, 5), dtype=int32, numpy=
array([[ 1, 2, 4, 5, 7],
[30, 20, 10, 50, 40]])>
答案 1 :(得分:1)
您可以使用argsort
上的A[0, :]
函数来计算列顺序,并使用gather
函数来计算新的张量。 cf. tensorflow.org。
import tensorflow as tf
A = tf.constant([[4, 2, 1, 7, 5],
[10, 20, 30, 40, 50]])
tf.gather(A, tf.argsort(A[0, :]), axis=1)