假设我有一个(batch_size, loss_dim)
的二维张量,我希望得到每个数据样本的每个损失维度的总和,这可以用tf.reduce_mean(tensor, axis=-1)
来完成。
但是,如果我的张量中有NaN值,我想在计算总和时忽略那些NaN怎么办?有谁知道怎么做?
PS。我知道我们可以使用tf.boolean_mask
来填充NaN,但如果我只是tensor = tf.boolean_mask(tensor, tf.logical_not(tf.is_nan(tensor))
,输出将被压缩到单个维度,这不是我想要的。
非常感谢你!
答案 0 :(得分:5)
您可以使用tf.where()
将tensor
中的NaN值替换为零,同时保留原始形状:
tensor = ...
# Replace all NaN values with 0.0.
tensor_without_nans = tf.where(tf.is_nan(tensor), tf.zeros_like(tensor), tensor)
sum_ignoring_nans = tf.reduce_sum(tensor_without_nans, axis=-1)