Tensorflow:按最大值过滤3D索引重复项

时间:2018-11-07 09:05:21

标签: python numpy tensorflow indexing tensor

我正在尝试创建一个过滤器蒙版,通过比较它们各自的较大值来从向量中删除重复的索引。

我目前的做法是:

  1. 将3D索引转换为1D
  2. 检查一维索引的唯一性
  3. 计算每个唯一索引的最大值
  4. 将最大值与原始值进行比较。如果存在相同的值,则保留该3-D索引。

我想获得一个过滤器数组,以便可以将boolean_mask应用于其他张量。对于此示例,掩码应如下所示: [False True True True True]

我当前的代码种类有效,除非值本身也被重复。但是,当我使用它时似乎是这种情况,因此我需要找到一种更好的解决方案。

这是我的代码外观的例证

import tensorflow as tf

# Dummy Input values with same Structure as the real
x_cells   = tf.constant([1,2,3,4,1], dtype=tf.int32)   # Index_1
y_cells   = tf.constant([4,4,4,4,4], dtype=tf.int32)   # Index_2
iou_index = tf.constant([1,2,3,4,1], dtype=tf.int32) # Index_3
iou_max   = tf.constant([1.,2.,3.,4.,5.], dtype=tf.float32) # Values

# my Output should be a mask that is [False True True True True]
# So if i filter this i get e.g. x_cells = [2,3,4,1] or iou_max = [2.,3.,4.,5.]

max_dim_y = tf.constant(10)
max_dim_x = tf.constant(20)
num_anchors = 5
stride = 32

# 1. Transforming the 3D-Index to 1D
tmp = tf.stack([x_cells, y_cells, iou_index], axis=1)
indices = tf.matmul(tmp, [[max_dim_y * num_anchors],     [num_anchors],[1]])

# 2. Looking for unique / duplicate indices
y, idx = tf.unique(tf.squeeze(indices))

# 3. Calculating the maximum values of each unique index.
# An function like unsorted_segment_argmax() would be awesome here
num_segments = tf.shape(y)[0]
ious = tf.unsorted_segment_max(iou_max, idx, num_segments)

iou_max_length = tf.shape(iou_max)[0]
ious_length = tf.shape(ious)[0]

# 4. Compare all max values to original values.
iou_max_tiled = tf.tile(iou_max, [ious_length])
iou_reshaped = tf.reshape(iou_max_tiled, [ious_length, iou_max_length])
iou_max_reshaped = tf.transpose(iou_reshaped)
filter_mask = tf.reduce_any(tf.equal(iou_max_reshaped, ious), -1)
filter_mask = tf.reshape(filter_mask, shape=[-1])

如果仅将开头的iou_max变量的值更改为:

,则上述代码将失败。

x_cells = tf.constant([1,2,3,4,1], dtype=tf.int32)
y_cells = tf.constant([4,4,4,4,4], dtype=tf.int32)
iou_index = tf.constant([1,2,3,4,1], dtype=tf.int32)
iou_max = tf.constant([2.,2.,3.,4.,5.], dtype=tf.float32)

1 个答案:

答案 0 :(得分:0)

我当前的解决方法更改了问题的第4点:

基本上,我更改了比较元组而不是单个值。这使我能够逻辑地检查索引AND值是否都在 3的其余值中。

# 4. Compare a Max Value and Indices with original values
rem_index_val_pair = tf.stack([ious, tf.cast(y, dtype=tf.float32)], axis=1)
orig_val_index_pair = tf.stack([iou_max, tf.cast(indices, dtype=tf.float32)], axis=1)

orig_val_index_pair_t = tf.tile(orig_val_index_pair, [1, ious_length])
orig_val_index_pair_s = tf.reshape(orig_val_index_pair_t, [iou_max_length, ious_length, 2])
filter_mask_1 = tf.equal(orig_val_index_pair_s, rem_index_val_pair)
filter_mask_2 = tf.reduce_all(filter_mask_1, -1)
filter_mask_3 = tf.reduce_any(filter_mask_2, -1)

# The orig_val_index_pair_s looks like the following
a =  [[[  2.  71.][  2.  71.][  2.  71.][  2.  71.]
     [[  2. 122.][  2. 122.][  2. 122.][  2. 122.]]
     [[  3. 173.][  3. 173.][  3. 173.][  3. 173.]]
     [[  4. 224.][  4. 224.][  4. 224.][  4. 224.]]
     [[  5.  71.][  5.  71.][  5.  71.][  5.  71.]]]
# I then compare it to the rem_max_val_pair which looks like this.
b =  [[  5.  71.][  2. 122.][  3. 173.][  4. 224.]]

# Using equal(a,b) will now compare each of the values resulting in:
c = [[[False  True][ True False][False False][False False]]
     [[False False][ True  True][False False][False False]]
     [[False False][False False][ True  True][False False]]
     [[False False][False False][False False][ True  True]]
     [[ True  True][False False][False False][False False]]]

# Using tf.reduce_all(c, -1) I can filter the bool pairs with a logical And. 
# (This kicks out my false positives from before).
# Afterwards I can check if the line has any true value by tf.reduce_any().

IMO,此解决方案仍然是一个肮脏的解决方法。因此,如果您有更好的解决方案建议,请分享。 :)