我目前正在按顺序设置中的变体自动编码器的变体,其中的任务是拟合/恢复一系列实值观测数据(因此这是一个回归问题)。
我已经使用tf.keras
构建了我的模型,并启用了急切执行和tensorflow_probability(tfp)。遵循VAE概念,生成网络发出观测数据的分布参数,我将其建模为多元正态。因此,输出是预测分布的均值和logvar。
关于训练过程,损失的第一部分是重建误差。给定生成网的预测(参数)分布,这就是真实观测的对数似然。在这里,我使用tfp.distributions
,因为它既方便又快捷。
但是,训练结束后,损失值相当低,结果表明我的模型似乎什么都不学。该模型的预测值在整个时间维度上几乎是平坦的(请记住,问题是连续的)。
尽管如此,出于完整性检查的考虑,当我用MSE损失替换对数可能性(在处理VAE时,这是没有道理的)时,它会产生非常好的数据拟合。因此,我得出结论,该对数似然项一定存在问题。有没有人对此有一些线索和/或解决方案?
我已经考虑过用交叉熵损失替换对数似然,但是我认为这不适用于我的情况,因为我的问题是回归,并且数据无法归一化为[0,1]范围。
当使用对数似然作为重建损失时,我还尝试实施退火的KL项(即,以恒定<1加权KL项)。但这也不起作用。
这是我原来的代码片段(使用对数似然作为重建误差)损失函数:
import tensorflow as tf
tfe = tf.contrib.eager
tf.enable_eager_execution()
import tensorflow_probability as tfp
tfd = tfp.distributions
def loss(model, inputs):
outputs, _ = SSM_model(model, inputs)
#allocate the corresponding output component
infer_mean = outputs[:,:,:latent_dim] #mean of latent variable from inference net
infer_logvar = outputs[:,:,latent_dim : (2 * latent_dim)]
trans_mean = outputs[:,:,(2 * latent_dim):(3 * latent_dim)] #mean of latent variable from transition net
trans_logvar = outputs[:,:, (3 * latent_dim):(4 * latent_dim)]
obs_mean = outputs[:,:,(4 * latent_dim):((4 * latent_dim) + output_obs_dim)] #mean of observation from generative net
obs_logvar = outputs[:,:,((4 * latent_dim) + output_obs_dim):]
target = inputs[:,:,2:4]
#transform logvar to std
infer_std = tf.sqrt(tf.exp(infer_logvar))
trans_std = tf.sqrt(tf.exp(trans_logvar))
obs_std = tf.sqrt(tf.exp(obs_logvar))
#computing loss at each time step
time_step_loss = []
for i in range(tf.shape(outputs)[0].numpy()):
#distribution of each module
infer_dist = tfd.MultivariateNormalDiag(infer_mean[i],infer_std[i])
trans_dist = tfd.MultivariateNormalDiag(trans_mean[i],trans_std[i])
obs_dist = tfd.MultivariateNormalDiag(obs_mean[i],obs_std[i])
#log likelihood of observation
likelihood = obs_dist.prob(target[i]) #shape = 1D = batch_size
likelihood = tf.clip_by_value(likelihood, 1e-37, 1)
log_likelihood = tf.log(likelihood)
#KL of (q|p)
kl = tfd.kl_divergence(infer_dist, trans_dist) #shape = batch_size
#the loss
loss = - log_likelihood + kl
time_step_loss.append(loss)
time_step_loss = tf.convert_to_tensor(time_step_loss)
overall_loss = tf.reduce_sum(time_step_loss)
overall_loss = tf.cast(overall_loss, dtype='float32')
return overall_loss