如何为VAE实施高斯混合?

时间:2019-03-24 15:22:55

标签: tensorflow distribution gaussian tensorflow-probability mixture

我觉得我真的不知道我在做什么,所以我将描述我认为我在做什么,我想做什么以及失败的地方。

给出普通的变分自动编码器:

...
net = tf.layers.dense(net, units=code_size * 2, activation=None)
mean = net[:, :code_size]
std = net[:, code_size:]
posterior = tfd.MultivariateNormalDiagWithSoftplusScale(mean, std)
net = posterior.sample()
net = tf.layers.dense(net, units=input_size, ...)
...

我想我在做什么:让神经网络找到一个“均值”和“标准差”值,并用它创建正态分布(高斯)。 从该分布中采样并用于解码器。 换句话说:学习编码的高斯分布

现在我想对混合高斯人做同样的事情。

...
net = tf.layers.dense(net, units=code_size * 2 * code_size, activation=None)

means, stds = tf.split(net, 2, axis=-1)

means = tf.split(means, code_size, axis=-1)
stds = tf.split(stds, code_size, axis=-1)

components = [tfd.MultivariateNormalDiagWithSoftplusScale(means[i], stds[i]) for i in range(code_size)]
probs = [1.0 / code_size] * code_size

gauss_mix = tfd.Mixture(cat=tfd.Categorical(probs=probs), components=components)
net = gauss_mix.sample()
net = tf.layers.dense(net, units=input_size, ...)
...

对于我来说,这似乎比较简单,但失败并出现以下错误:

  

形状()和(?,)不兼容

这似乎来自probs,它没有批处理维度(我不认为它需要这样做)。

我认为probs定义了组件之间的概率。

如果我定义的probs也具有批处理维度,则会出现以下神秘错误,我不知道它的含义:

  

维度-1796453376必须为> = 0

我通常会误解一些概念吗?

或者我需要做些什么?

0 个答案:

没有答案