Tensorflow,在给定条件的情况下更改Tensor值

时间:2018-11-08 17:43:25

标签: python tensorflow

我正在将numpy代码转换为Tensorflow。

它具有以下行:

netout[..., 5:] *= netout[..., 5:] > obj_threshold

这与Tensorflow语法不同,我很难找到具有相同行为的函数。

首先我尝试:

netout[..., 5:] * netout[..., 5:] > obj_threshold

但是返回的只是布尔型的Tensor。 在这种情况下,我希望obj_threshold以下的所有值都为0。

1 个答案:

答案 0 :(得分:1)

如果您只想将obj_threshold以下的所有值设为0,则可以执行以下操作:

netout = tf.where(netout > obj_threshold, netout, tf.zeros_like(netout))

或者:

netout = netout * tf.cast(netout > obj_threshold, netout.dtype)

但是,您的情况要复杂一些,因为您只希望更改影响到部分张量。因此,您可以做的是为True上的值或最后一个索引低于5的值制作一个obj_threshold的布尔掩码。

mask = (netout > obj_threshold) | (tf.range(tf.shape(netout)[-1]) < 5)

然后您可以将其与任何以前的方法一起使用:

netout = tf.where(mask, netout, tf.zeros_like(netout))

netout = netout * tf.cast(mask, netout.dtype)