我正在训练一个用于输出 PyTorch 中正态分布的均值和西格玛值的光流网络。对于训练,我计算 NLL 损失如下,但有时我会得到 'nan' 值。此外,对于 sigma,不要在 0 和 1 之间使用负值或浮动值,我采用 sigma 通道的 ELU 并添加 2,如下所示,但我无法从 NLL 损失中找到这些“nan”值的原因。
你能帮我解决这个问题吗?
def NegativeLogLikelihoodLoss(predicted_flow, target_flow, mask):
u_component_mean = predicted_flow[:, 0, :, :]
u_component_sigma = predicted_flow[:, 2, :, :]
u_component_distribution = torch.distributions.normal.Normal(u_component_mean, u_component_sigma)
u_component_loss = -u_component_distribution.log_prob(target_flow[:, 0, :, :])
v_component_mean = predicted_flow[:, 1, :, :]
v_component_sigma = predicted_flow[:, 3, :, :]
v_component_distribution = torch.distributions.normal.Normal(v_component_mean, v_component_sigma)
v_component_loss = -v_component_distribution.log_prob(target_flow[:, 1, :, :])
return torch.mul(u_component_loss, mask[:, 0, :, :]).mean() + torch.mul(v_component_loss, mask[:, 1, :, :]).mean()
elu = torch.nn.ELU()
sigma_part = elu(flow_prediction[:, 2:4, :, :])
sigma_part = torch.add(sigma_part, 2)
flow_prediction= torch.cat((flow_prediction[:, 0:2, :, :], sigma_part), 1)