我有一批张量
tensors = tf.constant([[1,2,3], [2, 4, 5], [1,2,7]], dtype=tf.float32)
我有一种机制可以通过tf.boolean_mask
在其中选择一些张量:
bools = [0, 0, 0]
tensors_sorted = tf.boolean_mask(tensors, mask=bools)
问题是当所有布尔值都为零时,tensors_sorted
变为空张量。
tensors_sorted_out = sess.run(tensors_sorted) # array([], shape=(0, 3), dtype=float32)
现在,在这种情况下,我希望tensors
是标量0.
。我知道tf.cond
tensors_sorted = tf.cond(tf.reduce_sum(bools)>0, lambda:tensors_sorted, lambda: 0.)
tensors_sorted_out = sess.run(tensors_sorted) # tensors_sorted_out = 0.
但是这个看起来很慢。有没有比这更快的方法了?
添加了注释*:实际上,我们可能会考虑将tenosrs_sorted
更改为具有相同形状的零张量的选项。
答案 0 :(得分:1)
我们可能会考虑将tenosrs_sorted更改为具有相同形状的零张量的选项。可能使用tf.where
:
t = tf.where(tf.equal(bools, 0), tf.zeros_like(tensors), tensors)
t.eval()
#array([[ 0., 0., 0.],
# [ 0., 0., 0.],
# [ 0., 0., 0.]], dtype=float32)