如何检查张量B中是否也存在张量A的元素,并创建布尔掩码

时间:2019-12-26 20:02:53

标签: python tensorflow tensor tensorflow2.0

在tensorflow 2.0中,我有两个整数张量(tf.uint8),我们称它们为A和B. 张量A的秩是任意的,而B是单维的。 我正在寻找的结果是获得布尔(tf.bool)的张量C,使得:

(例如,假设A为3级)

  • C的形状等于A的形状
  • 当且仅当B中存在A [i,j,k]时,C [i,j,k]必须为True

(i,j,k是用于澄清概念的索引)

要总结一下,我需要检查A的元素是否在B中,并创建一个遮罩(C),该掩码说明A的哪些元素在B中,哪些不在。

视觉示例(实际上不是代码,只是研究行为的视觉表示):

 A = [[1,2,3],
     [4,5,6]]

 B = [1,5]

 C = [[True, False, False],
     [False, True, False]]

1 个答案:

答案 0 :(得分:1)

您可以执行以下操作。我找不到向量化方式解决此问题的方法,因为您希望它可以在任意大小的A上使用。但是,只要B不是很长,这应该可以正常工作。

A = tf.constant([[1,2,3],[4,5,6]])

B = tf.constant([1,5])

C = tf.math.greater(tf.reduce_sum(tf.map_fn(lambda b: tf.cast(tf.math.equal(A,b), tf.int32), B), axis=0),0)