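I am trying to create a tf.boolean_mask that filters duplicate indices out of a tensor based on the value associated with each index. If a position's value is greater than the values of its duplicates, it should be kept and the others discarded. If both the index and the value are identical, only one of them should be kept: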
[Pseudocode]
for index in indices:
    if index is unique:
        keep index = True
    else:
        if val[index] > val[index of all other duplicate indices]:
            keep index = True
        elif val[index] < val[index of any other duplicate indices]:
            keep index = False
        elif val[index] == val[index of any other duplicate indices]:
            keep only a single one of the equal indices (doesn't matter which)
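To pin down the intended semantics before vectorizing, here is a plain-Python loop version of the pseudocode. It is only a reference sketch: the function name keep_mask_reference is hypothetical, and it assumes ties are resolved by keeping the earliest position (the pseudocode says any one is fine).

```python
def keep_mask_reference(index, value):
    """Loop version of the pseudocode (hypothetical helper).

    For each distinct index, exactly one position is kept: the one with the
    largest value. Ties are resolved in favour of the earliest position.
    """
    best = {}  # distinct index -> position of the best value seen so far
    for pos, (i, v) in enumerate(zip(index, value)):
        # Strict '>' means an equal value never replaces an earlier position.
        if i not in best or v > value[best[i]]:
            best[i] = pos
    kept = set(best.values())
    return [pos in kept for pos in range(len(index))]


print(keep_mask_reference([10, 5, 20, 20, 30, 30], [1., 0., 2., 0., 0., 0.]))
# [True, True, True, False, True, False]
```

Such a loop is useful as a ground truth when checking a TensorFlow version on random inputs.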
A minimal example of the problem:
import tensorflow as tf
tf.enable_eager_execution()
index = tf.convert_to_tensor([ 10, 5, 20, 20, 30, 30])
value = tf.convert_to_tensor([ 1., 0., 2., 0., 0., 0.])
# bool_mask = [True, True, True, False, True, False]
# or [True, True, True, False, False, True]
# position 3 is filtered out because position 2 has the same index (20)
# and a greater value (2. vs. 0.)
# positions 4 and 5 have the same index (30) and the same value (0.), so either
# one of them may be kept, but at most one.
...
bool_mask = ?
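My current approach correctly removes duplicates whose values differ, but fails for duplicates whose values are equal. Unfortunately, this edge case does occur in my data: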
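import tensorflow as tf
# y: distinct indices, idx: group id of every position
y, idx = tf.unique(index)
num_segments = tf.shape(y)[0]
# largest value per distinct index
maximum_vals = tf.unsorted_segment_max(value, idx, num_segments)
# (index, max value) pairs, one row per distinct index
fused_filt = tf.stack([tf.cast(y, tf.float32), maximum_vals], axis=1)
# (index, value) pairs, one row per position
fused_orig = tf.stack([tf.cast(index, tf.float32), value], axis=1)
# compare every position against every (index, max value) pair
fused_orig_tiled = tf.tile(fused_orig, [1, tf.shape(fused_filt)[0]])
fused_orig_res = tf.reshape(fused_orig_tiled, [-1, tf.shape(fused_filt)[0], 2])
comp_1 = tf.equal(fused_orig_res, fused_filt)
comp_2 = tf.reduce_all(comp_1, -1)
comp_3 = tf.reduce_any(comp_2, -1)
# comp_3 = [True, True, True, False, True, True]
# positions 4 and 5 are both kept, since both match the maximum for index 30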
A pure TensorFlow solution would be preferred; a for loop over the indices would be easy to write. Thanks.
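One possible pure-TensorFlow extension of the approach above, offered as a sketch rather than a definitive answer: keep the per-index maximum exactly as in the question, then break ties by position, keeping the smallest position among those that attain the maximum. The function name dedup_mask is hypothetical. It assumes a TensorFlow version that exposes tf.math.unsorted_segment_max and tf.math.unsorted_segment_min (recent 1.x and 2.x; older 1.x releases spell them tf.unsorted_segment_max / tf.unsorted_segment_min), eager execution for the .numpy() call, and no NaNs in value.

```python
import tensorflow as tf


def dedup_mask(index, value):
    """Hypothetical helper: for each distinct index keep one position holding
    the maximum value; ties go to the earliest position (assumes no NaNs)."""
    y, idx = tf.unique(index)
    num_segments = tf.shape(y)[0]
    # Maximum value of each index group, as in the question.
    maximum_vals = tf.math.unsorted_segment_max(value, idx, num_segments)
    # True where a position attains the maximum of its group (may be several).
    is_max = tf.equal(value, tf.gather(maximum_vals, idx))
    n = tf.shape(index)[0]
    positions = tf.range(n)
    # Non-maximal positions get n, which is larger than any real position.
    candidates = tf.where(is_max, positions, tf.fill(tf.shape(positions), n))
    # Smallest maximal position per group: this is the single position kept.
    first_pos = tf.math.unsorted_segment_min(candidates, idx, num_segments)
    return tf.equal(positions, tf.gather(first_pos, idx))


mask = dedup_mask(tf.constant([10, 5, 20, 20, 30, 30]),
                  tf.constant([1., 0., 2., 0., 0., 0.]))
print(mask.numpy().tolist())  # [True, True, True, False, True, False]
```

The tie-break by smallest position is an arbitrary but deterministic choice; since the question allows either of the equal duplicates, any rule that selects exactly one maximal position per group would do. The resulting mask can be passed directly to tf.boolean_mask.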