火炬加权的MSE损失

时间:2019-07-12 09:49:45

标签: python pytorch

def weighted_mse_loss(input_tensor, target_tensor, weight = 1):
    observation_dim = input_tensor.size()[-1]
    streched_tensor = ((input_tensor - target_tensor) ** 2).view(-1, observation_dim)
    entry_num = float(streched_tensor.size())[0]
    non_zero_entry_num = torch.sum(streched_tensor[:,0] != 0).float()
    weighted_tensor = torch.mm(
        ((input_tensor - target_tensor)**2).view(-1, observation_dim),
        (torch.diag(weight.float().view(-1)))
    )
    return torch.mean(weighted_tensor) * weight.nelement() * entry_num / non_zero_entry_num

我无法理解代码如何给出加权均方误差损失。 我得到observation_dim是最终的输出尺寸(我猜是类号),在那一行之后,我不明白。有人可以帮我弄清楚代码如何计算损失吗?

非常感谢。

1 个答案:

答案 0 :(得分:0)

    def weighted_mse_loss(input, target, weight):
        return (weight * (input - target) ** 2).mean()

尝试一下,希望对您有所帮助。 所有参数都需要张量。