TensorFlow基于布尔掩码选择条目(来自两个张量之一)

时间:2016-10-07 07:41:17

标签: tensorflow

我有三个张量,abmask,形状相同。我想生成一个新的张量c,这样c的每个条目都取自a的相应条目,如果mask的相应条目为True;否则,它取自b的相应条目。

示例:

a = [0, 1, 2]
b = [10, 20, 30]
mask = [True, False, True]
c = [0, 20, 2]

我该怎么做?

2 个答案:

答案 0 :(得分:4)

为什么不使用tf.select(condition, t, e, name=None)

为您的例子:

c = tf.select(mask, a, b)

有关tf.select的更多详情,请访问Tensorflow Control Flow Documentation

答案 1 :(得分:1)

你可以这样做:

1) convert mask to ints (0 for false, 1 for true)
2) do element wise multiplication of int_mask with tensor 'a' 
    (elements that should not be included are going to be 0)
3) do logical_not on mask
4) convert logical_not_int_mask to ints 
   (again 0 for false, 1 for true values)
5) now just do element wise multiplication of logical_not_int_mask with tensor 'b' 
   (elements that should not be included are going to be 0)
6) Add tensors 'a' and 'b' together and there you have it.

在代码中它应该看起来像这样:

# tensor 'a' is [0, 1, 2]
# tensor 'b' is [10, 20, 30]
# tensor 'mask' is [True, False, True]

int_mask = tf.cast(mask, tf.int32)
# Leave only important elements in 'a'
a = tf.mul(a, int_mask) 
mask = tf.logical_not(mask)
int_mask = tf.cast(mask, tf.int32)
b = tf.mul(b, int_mask)
result = tf.add(a, b)

或者只是像已经提到的那样使用tf.select()函数。