计算均方根值时Pytorch面罩缺少值

时间:2019-01-18 07:55:01

标签: pytorch

我正在尝试计算两个火炬张量的均方根误差。我想忽略/屏蔽标签为0(缺少值)的行。我如何修改此行以考虑该限制?

torch.sqrt(((preds.detach() - labels) ** 2).mean()).item()

谢谢。

1 个答案:

答案 0 :(得分:1)

这可以通过定义自定义的MSE损失函数*来解决,该函数从输入和目标张量中掩盖缺失值(在您的情况下为0):

def mse_loss_with_nans(input, target):

    # Missing data are nan's
    # mask = torch.isnan(target)

    # Missing data are 0's
    mask = target == 0

    out = (input[~mask]-target[~mask])**2
    loss = out.mean()

    return loss

(*)从优化角度来看,计算MSE等效于RMSE,其优点是计算速度更快。