如何在 Pytorch 中手动获取负对数似然?

时间:2021-04-15 19:35:24

标签: pytorch

我正在实施一个 VAE,我想手动获取负对数似然(不使用现有函数)。给定的方程是equation1,我还发现它也可以表示为equation2。我已经坚持了几天了,不知道我的代码哪里错了。

def loss_loglik(y_mean, y_logvar, x):
    out_1 = (x.size()[2]*x.size()[3] / 2) * np.log(2 * np.pi)
    out_2 = (x.size()[2]*x.size()[3] / 2) * torch.log(y_logvar.exp())
    x_diff = x - y_mean
    out_3 = torch.sum(x_diff.pow(2)) / (2 * y_logvar.exp())
    loss = out_1 + out_2 + out_3

三个参数的形状是 (batch_size, 1, 28, 28)。

1 个答案:

答案 0 :(得分:0)

好像是等式。 2错了。

应该是这样的。不是我推导的,只是想和输入的匹配...所以,请验证。

我在下面修改了你的函数。

def loss_loglik(y_mean, y_logvar, x):
    m, n  = x.size()[2], x.size()[3]
    b = x.size()[0]
    
    y_mean = y_mean.reshape(B, -1)
    y_logvar = y_loagvar.reshape(B, -1)


    out_1 = (m*n / 2) * np.log(2 * np.pi)
    
    out_2 = (1 / 2) * y_logvar.sum(dim=1)
    
    x_diff = x - y_mean
    # note sigma is inside the sum
    out_3 = torch.sum(x_diff.pow(2) / (2 * y_logvar.exp()), dim=1)
    
    loss = out_1 + out_2 + out_3
    
    return -loss      
    # shape of loss will be (batchsize,) 
相关问题