我想对一维张量进行排序,以便根据张量中项目的频率对张量进行排序,例如给定张量[6,0,1,0,2,6,6],I想要像这样排序[6,6,6,0,0,1,2],到目前为止我到目前为止:
ks = tf.constant([6,0,1,0,2,6,6])
unique_s, idx, cnts = tf.unique_with_counts(ks)
r = tf.gather(unique_s,tf.nn.top_k(cnts, k=(1+tf.reduce_max(idx))).indices)
s = tf.gather(cnts,tf.nn.top_k(cnts, k=(1+tf.reduce_max(idx))).indices)
其中r包含值[6,0,1,2],s包含[3,2,1,1]。现在,我想根据s中的计数扩展r。所以,在Python中,我们可以像这样制作上面的列表:
sorted_arr = []
for i,_s in enumerate(s):
sorted_arr.expand([r[i]]*_s)
但由于张量流中不允许迭代张量,我现在有点卡住了。
答案 0 :(得分:0)
我很确定有一个更优雅的解决方案,但这里有:
这需要运行两个会话:
代码:
import tensorflow as tf
ks = tf.constant([6,0,1,0,2,6,6])
unique_s, idx, cnts = tf.unique_with_counts(ks)
r = tf.gather(unique_s,tf.nn.top_k(cnts, k=(1+tf.reduce_max(idx))).indices)
# Dynamic length of r
len_r = tf.shape(r)
s = tf.gather(cnts,tf.nn.top_k(cnts, k=(1+tf.reduce_max(idx))).indices)
# Evaluate number of unique values
with tf.Session() as sess:
splits = len_r.eval()[0]
# Split r & s into a list of tensors
split_r = tf.split(r, splits)
split_s = tf.split(s, splits)
# Tile r, s times
mult = [tf.tile(each_r,each_s) for each_r, each_s in zip(split_r, split_s)]
# Concatenate
join = tf.concat(mult, axis=0)
with tf.Session() as sess:
print(join.eval())
输出[6 6 6 0 0 1 2]