根据另一个张量(tensorflow)的计数创建张量

时间:2018-02-11 03:32:42

标签: python tensorflow

我想对一维张量进行排序,以便根据张量中项目的频率对张量进行排序,例如给定张量[6,0,1,0,2,6,6],I想要像这样排序[6,6,6,0,0,1,2],到目前为止我到目前为止:

ks = tf.constant([6,0,1,0,2,6,6])
unique_s, idx, cnts = tf.unique_with_counts(ks)
r = tf.gather(unique_s,tf.nn.top_k(cnts, k=(1+tf.reduce_max(idx))).indices)
s = tf.gather(cnts,tf.nn.top_k(cnts, k=(1+tf.reduce_max(idx))).indices)

其中r包含值[6,0,1,2],s包含[3,2,1,1]。现在,我想根据s中的计数扩展r。所以,在Python中,我们可以像这样制作上面的列表:

sorted_arr = []
for i,_s in enumerate(s):
    sorted_arr.expand([r[i]]*_s)

但由于张量流中不允许迭代张量,我现在有点卡住了。

1 个答案:

答案 0 :(得分:0)

我很确定有一个更优雅的解决方案,但这里有:

这需要运行两个会话:

  1. 查找唯一值的数量
  2. 分裂,倍增&结合
  3. 代码:

    import tensorflow as tf
    ks = tf.constant([6,0,1,0,2,6,6])
    unique_s, idx, cnts = tf.unique_with_counts(ks)
    r = tf.gather(unique_s,tf.nn.top_k(cnts, k=(1+tf.reduce_max(idx))).indices)
    # Dynamic length of r
    len_r = tf.shape(r)
    s = tf.gather(cnts,tf.nn.top_k(cnts, k=(1+tf.reduce_max(idx))).indices)
    
    # Evaluate number of unique values
    with tf.Session() as sess:
        splits = len_r.eval()[0]
    
    # Split r & s into a list of tensors
    split_r = tf.split(r, splits)
    split_s = tf.split(s, splits)
    # Tile r, s times
    mult = [tf.tile(each_r,each_s) for each_r, each_s in zip(split_r, split_s)]
    # Concatenate
    join = tf.concat(mult, axis=0)
    
    with tf.Session() as sess:
        print(join.eval())
    

    输出[6 6 6 0 0 1 2]