变分自动编码器回归类型的重构损失

时间:2019-04-12 14:51:34

标签: python tensorflow keras recurrent-neural-network eager-execution

我目前正在按顺序设置中的变体自动编码器的变体,其中的任务是拟合/恢复一系列实值观测数据(因此这是一个回归问题)。

我已经使用tf.keras构建了我的模型,并启用了急切执行和tensorflow_probability(tfp)。遵循VAE概念,生成网络发出观测数据的分布参数,我将其建模为多元正态。因此,输出是预测分布的均值和logvar。

关于训练过程,损失的第一部分是重建误差。给定生成网的预测(参数)分布,这就是真实观测的对数似然。在这里,我使用tfp.distributions,因为它既方便又快捷。

但是,训练结束后,损失值相当低,结果表明我的模型似乎什么都不学。该模型的预测值在整个时间维度上几乎是平坦的(请记住,问题是连续的)。

尽管如此,出于完整性检查的考虑,当我用MSE损失替换对数可能性(在处理VAE时,这是没有道理的)时,它会产生非常好的数据拟合。因此,我得出结论,该对数似然项一定存在问题。有没有人对此有一些线索和/或解决方案?

我已经考虑过用交叉熵损失替换对数似然,但是我认为这不适用于我的情况,因为我的问题是回归,并且数据无法归一化为[0,1]范围。

当使用对数似然作为重建损失时,我还尝试实施退火的KL项(即,以恒定<1加权KL项)。但这也不起作用。

这是我原来的代码片段(使用对数似然作为重建误差)损失函数:

    import tensorflow as tf
    tfe = tf.contrib.eager
    tf.enable_eager_execution()

    import tensorflow_probability as tfp
    tfd = tfp.distributions

    def loss(model, inputs):
        outputs, _ = SSM_model(model, inputs)

        #allocate the corresponding output component
        infer_mean = outputs[:,:,:latent_dim]  #mean of latent variable from  inference net
        infer_logvar = outputs[:,:,latent_dim : (2 * latent_dim)]
        trans_mean = outputs[:,:,(2 * latent_dim):(3 * latent_dim)] #mean of latent variable from transition net
        trans_logvar = outputs[:,:, (3 * latent_dim):(4 * latent_dim)]
        obs_mean = outputs[:,:,(4 * latent_dim):((4 * latent_dim) + output_obs_dim)] #mean of observation from  generative net
        obs_logvar = outputs[:,:,((4 * latent_dim) + output_obs_dim):]
        target = inputs[:,:,2:4]

        #transform logvar to std
        infer_std = tf.sqrt(tf.exp(infer_logvar))
        trans_std = tf.sqrt(tf.exp(trans_logvar))
        obs_std = tf.sqrt(tf.exp(obs_logvar))

        #computing loss at each time step
        time_step_loss = []
        for i in range(tf.shape(outputs)[0].numpy()):
            #distribution of each module
            infer_dist = tfd.MultivariateNormalDiag(infer_mean[i],infer_std[i])
            trans_dist = tfd.MultivariateNormalDiag(trans_mean[i],trans_std[i])
            obs_dist = tfd.MultivariateNormalDiag(obs_mean[i],obs_std[i])

            #log likelihood of observation
            likelihood = obs_dist.prob(target[i]) #shape = 1D = batch_size
            likelihood = tf.clip_by_value(likelihood, 1e-37, 1)
            log_likelihood = tf.log(likelihood)

            #KL of (q|p)
            kl = tfd.kl_divergence(infer_dist, trans_dist) #shape = batch_size

            #the loss
            loss = - log_likelihood + kl
            time_step_loss.append(loss)

        time_step_loss = tf.convert_to_tensor(time_step_loss)        
        overall_loss = tf.reduce_sum(time_step_loss)
        overall_loss = tf.cast(overall_loss, dtype='float32')

        return overall_loss

0 个答案:

没有答案