未排序的段argmax解决方法tensorflow

时间:2019-01-31 14:25:06

标签: python tensorflow

我正在尝试创建一个tf_boolean_mask,以按索引的值过滤来自张量的重复索引。如果该值大于重复值,则应保留该值,其他值将被丢弃。如果索引和值相同,则只能保留一个:

[Pseudocode]
for index in indices
    If index is unique:
        keep index = True
    else:
        if val[index] > val[index of all other duplicate indices]:
            keep index = True
        elif val[index] < val[index of any other duplicate indices]:
            keep index = False
        elif val[index] == val[index of any other duplicate indices]:
            keep only a single one of the equal indices(doesn't matter which)   

此问题的简短示例如下:

import tensorflow as tf
tf.enable_eager_execution()

index = tf.convert_to_tensor([  10,    5,   20,    20,    30,    30])
value = tf.convert_to_tensor([  1.,   0.,   2.,    0.,    0.,    0.])
# bool_mask =                [True, True, True, False,  True, False]
# or                         [True, True, True, False, False,  True]
# the index 3 is filtered because index 2 has a greater value (2 comp. to 0)
# The index 4 and 5 are identical in their respective values, that's why both
# of them can be kept, but at maximum one of them. 


...
bool_mask = ?

我当前的方法成功地解决了删除具有不同值的重复项的问题,但是未能解决具有相同值的重复项的问题。不幸的是,这是一个极端的情况,出现在我的数据中:

import tensorflow as tf

y, idx = tf.unique(index) 
num_segments = tf.shape(y)[0]
maximum_vals = tf.unsorted_segment_max(value, idx, num_segments)

fused_filt = tf.stack([tf.cast(y, tf.float32), maximum_vals],axis=1)
fused_orig = tf.stack([tf.cast(index, tf.float32), value], axis=1)

fused_orig_tiled = tf.tile(fused_orig, [1, tf.shape(fused_filt)[0]])
fused_orig_res = tf.reshape(fused_orig_tiled, [-1, tf.shape(fused_filt)[0], 2])

comp_1 = tf.equal(fused_orig_res, fused_filt)
comp_2 = tf.reduce_all(comp_1, -1)
comp_3 = tf.reduce_any(comp_2, -1)
# comp_3 = [True, True, True, False, True, True]

纯的tensorflow解决方案将是不错的选择,因为可以很简单地实现索引上的For循环。谢谢。

0 个答案:

没有答案