为Tensorflow / Keras中的重复元素创建遮罩

时间:2020-10-20 14:47:34

标签: python tensorflow keras deep-learning

我正在尝试为个人识别任务编写自定义损失函数,该函数在多任务学习设置中与对象检测一起进行训练。过滤后的标签值的形状为(batch_size,num_boxes)。我想创建一个掩码,以便仅考虑在dim 1中重复的值以进行进一步的计算。如何在TF / Keras后端中做到这一点?

简短示例

a.split()

(基本上,我只想过滤掉重复项,并为损失函数放弃唯一标识)。

我想可以同时使用tf.unique和tf.scatter,但我不知道如何。

1 个答案:

答案 0 :(得分:1)

此代码有效:

x = tf.constant([[0,0,0,0,12,12,3,3,4], [0,0,10,10,10,12,3,3,4]])
def mark_duplicates_1D(x):
  y, idx, count = tf.unique_with_counts(x)
  comp = tf.math.greater(count, 1)
  comp = tf.cast(comp, tf.int32)
  res = tf.gather(comp, idx)
  mult = tf.math.not_equal(x, 0)
  mult = tf.cast(mult, tf.int32)
  res *= mult
  return res
res = tf.map_fn(fn=mark_duplicates_1D, elems=x)