Tensorflow:基于另一个数组的重复元素对数组进行屏蔽

时间:2018-08-24 06:33:08

标签: python tensorflow unique mask

我有一个数组x=[2, 3, 4, 3, 2],其中包含模型的状态,另一个数组给出这些状态的相应概率,prob=[.2, .1, .4, .1, .2]。但是有些状态是重复的,我需要对它们对应的概率求和。所以我想要的输出是:unique_elems=[2, 3, 4]reduced_prob=[.2+.2, .1+.1, .4]。这是我的方法:

x = tf.constant([2, 3, 4, 3, 2])
prob = tf.constant([.2, .1, .4, .1, .2])
unique_elems, _ = tf.unique(x)  # [2, 3, 4]
unique_elems = tf.expand_dims(unique_elems, axis=1) # [[2], [3], [4]]

tiled_prob = tf.tile(tf.expand_dims(prob, axis=0), [3, 1]) 
# [[0.2, 0.1, 0.4, 0.1, 0.2],
#  [0.2, 0.1, 0.4, 0.1, 0.2],
#  [0.2, 0.1, 0.4, 0.1, 0.2]]

equal = tf.equal(x, unique_elems)
# [[ True, False, False, False,  True],
#  [False,  True, False,  True, False],
#  [False, False,  True, False, False]]

reduced_prob = tf.multiply(tiled_prob, tf.cast(equal, tf.float32))
# [[0.2, 0. , 0. , 0. , 0.2],
#  [0. , 0.1, 0. , 0.1, 0. ],
#  [0. , 0. , 0.4, 0. , 0. ]]

reduced_prob = tf.reduce_sum(reduced_prob, axis=1)
# [0.4, 0.2, 0.4]

但是我想知道是否有更有效的方法来做到这一点。特别是我正在使用切片操作,我认为这对于大型阵列不是很有效。

1 个答案:

答案 0 :(得分:0)

tf.unsorted_segment_sum可以两行完成:

unique_elems, idx = tf.unique(x)  # [2, 3, 4]
reduced_prob = tf.unsorted_segment_sum(prob, idx, tf.size(unique_elems))