我该如何实现一个VAE,以便在编码空间中学习多元正态全协方差分布?

时间:2019-07-09 04:11:46

标签: r tensorflow-probability

我在R中使用tfprobability创建了VAE。虽然KL散度项通常适合于正态独立分布N(0,I)的潜在空间,但我对指定完整的协方差矩阵感兴趣。

我可以在复制示例 https://www.tensorflow.org/probability/api_docs/python/tfp/layers/KLDivergenceAddLoss 但这使用了tfd.MultivariateNormalDiag,要求协方差矩阵是对角线。我已经尝试过同时使用tfd.MultivariateNormalTriL和tfd.MultivariateNormalFullCovariance作为KLDivergenceAddLoss函数的distribution_b参数,但是无论哪种方式我都会收到“ NotImplementedError”。

library(keras)
library(dplyr)
library(tensorflow)
library(tfprobability)

encoder_model <- keras_model_sequential() %>%
    layer_flatten(input_shape=28) %>%
    layer_dense(units=10, activation='relu') %>%
    layer_dense(units = params_size_multivariate_normal_tri_l(3), name='encoded') %>%
    layer_multivariate_normal_tri_l(event_size=3)  %>%
    layer_kl_divergence_add_loss(
        distribution_b = tfd_multivariate_normal_full_covariance(
            loc=c(0,0,0),
            covariance_matrix = diag(3)))

0 个答案:

没有答案