如何按给定的分段收集张量元素?

时间:2019-06-27 17:45:20

标签: python tensorflow

我正在尝试实现“ segment_collect”(非常类似于segment_max,但是收集到张量中而不是使用max)。

t = tf.constant(["a", "b", "c", "d"])
s = tf.constant([0, 1, 1, 0])
r = tf.segment_collect(t, s)  # r == [["a", "d"], ["b", "c"]]

一个简单的实现是用下面的伪代码逐行构建结果:

r = []
for i in range(2):
    mask = tf.equal(s, i)
    values = tf.boolean_mask(t, mask)
    r.append(values)
# convert r into a tensor at last

但这不是很有效。

一个后续问题是:是否有一种通用的方法来对张量进行分组/汇总?除了张量流中的segment_ {min / max / mean / prod / sum},这将允许更多操作,例如segment_size,segment_median,segment_percentile。

1 个答案:

答案 0 :(得分:0)

您可能会发现tf.gathertf.nn.topk有帮助:

tf.gather(t, tf.nn.top_k(-s, k=tf.shape(s)[0]).indices) 

这对TF 1.x和TF 2.0均适用。如有必要,重塑结果:

tf.reshape(tf.gather(t, tf.nn.top_k(-s, k=tf.shape(s)[0]).indices), shape=(-1, 2)) 

当然,重塑假定元素被划分为大小相等的组(在这种情况下为两组)。


tf.keras.backend.eval(tf.gather(t, tf.nn.top_k(-s, k=tf.shape(s)[0]).indices))                                                           
# array([b'a', b'd', b'b', b'c'], dtype=object)