Tensorflow - 使用自定义比较器对张量进行排序

时间:2018-05-03 11:46:12

标签: python sorting tensorflow

如何根据仅使用Tensorflow操作的自定义比较函数对形状[n,2]的Tensorflow张量进行排序?

让我们说张量中的两个条目是[x1,y1]和[x2,y2]。我想对张量进行排序,使条目按条件x1 * y2>重新排序。 x2 * y1。

2 个答案:

答案 0 :(得分:2)

假设您可以为元素创建一个指标(如果没有,请参见下面的一般情况)(此处,将不等式重新排列为 x1 / y1> x2 / y2 ,因此指标将 x / y 并依靠TensorFlow生成 inf (如无限制)为了除以零),像这段代码一样使用tf.nn.top_k()(已测试):

import tensorflow as tf

x = tf.constant( [ [1,2], [3,4], [1,3], [2,5] ] ) # some example numbers

s = tf.truediv( x[ ..., 0 ], x[ ..., 1 ] ) # your sort condition
val, idx = tf.nn.top_k( s, x.get_shape()[ 0 ].value )
x_sorted = tf.gather( x, idx )

with tf.Session() as sess:
    print( sess.run( x_sorted ) )

输出:

  

[[3 4]
   [1 2]
   [2 5]
   [1 3]]

如果您不能或不能轻松创建指标,那么仍然假设关系为您提供well-ordering。 (否则结果是未定义的。)在这种情况下,您为整个集合构建比较矩阵,并按行总和对元素进行排序(即,有多少其他元素更大);这当然是要排序的元素数量的二次方。此代码(已测试):

import tensorflow as tf

x = tf.constant( [ [1,2], [3,4], [1,3], [2,5] ] ) # some example numbers

x1, y1 = x[ ..., 0 ][ None, ... ], x[ ..., 1 ][ None, ... ] # expanding dims into cols
x2, y2 = x[ ..., 0, None ],        x[ ..., 1, None ] # expanding into rows
r = tf.cast( tf.less( x1 * y2, x2 * y1 ), tf.int32 ) # your sort condition, with implicit broadcasting
s = tf.reduce_sum( r, axis = 1 ) # how many other elements are greater

val, idx = tf.nn.top_k( s, s.get_shape()[ 0 ].value )
x_sorted = tf.gather( x, idx )

with tf.Session() as sess:
    print( sess.run( x_sorted ) )

输出:

  

[[3 4]
   [1 2]
   [2 5]
   [1 3]]

答案 1 :(得分:0)

作为top_k Peter Szoldan答案的替代方法,在1.13之后有一个tf.argsort

对于<1.13,请使用

tf.nn.top_k( s, tf.shape(x)[0] )

如果无法在静态图形中获得形状。