我正在实施一个 VAE,我想手动获取负对数似然(不使用现有函数)。给定的方程是equation1,我还发现它也可以表示为equation2。我已经坚持了几天了,不知道我的代码哪里错了。
def loss_loglik(y_mean, y_logvar, x):
out_1 = (x.size()[2]*x.size()[3] / 2) * np.log(2 * np.pi)
out_2 = (x.size()[2]*x.size()[3] / 2) * torch.log(y_logvar.exp())
x_diff = x - y_mean
out_3 = torch.sum(x_diff.pow(2)) / (2 * y_logvar.exp())
loss = out_1 + out_2 + out_3
三个参数的形状是 (batch_size, 1, 28, 28)。
答案 0 :(得分:0)
好像是等式。 2错了。
应该是这样的。不是我推导的,只是想和输入的匹配...所以,请验证。
我在下面修改了你的函数。
def loss_loglik(y_mean, y_logvar, x):
m, n = x.size()[2], x.size()[3]
b = x.size()[0]
y_mean = y_mean.reshape(B, -1)
y_logvar = y_loagvar.reshape(B, -1)
out_1 = (m*n / 2) * np.log(2 * np.pi)
out_2 = (1 / 2) * y_logvar.sum(dim=1)
x_diff = x - y_mean
# note sigma is inside the sum
out_3 = torch.sum(x_diff.pow(2) / (2 * y_logvar.exp()), dim=1)
loss = out_1 + out_2 + out_3
return -loss
# shape of loss will be (batchsize,)