我正在尝试计算两个火炬张量的均方根误差。我想忽略/屏蔽标签为0(缺少值)的行。我如何修改此行以考虑该限制?
torch.sqrt(((preds.detach() - labels) ** 2).mean()).item()
谢谢。
答案 0 :(得分:1)
这可以通过定义自定义的MSE损失函数*来解决,该函数从输入和目标张量中掩盖缺失值(在您的情况下为0):
def mse_loss_with_nans(input, target):
# Missing data are nan's
# mask = torch.isnan(target)
# Missing data are 0's
mask = target == 0
out = (input[~mask]-target[~mask])**2
loss = out.mean()
return loss
(*)从优化角度来看,计算MSE等效于RMSE,其优点是计算速度更快。