我觉得我真的不知道我在做什么,所以我将描述我认为我在做什么,我想做什么以及失败的地方。
给出普通的变分自动编码器:
...
net = tf.layers.dense(net, units=code_size * 2, activation=None)
mean = net[:, :code_size]
std = net[:, code_size:]
posterior = tfd.MultivariateNormalDiagWithSoftplusScale(mean, std)
net = posterior.sample()
net = tf.layers.dense(net, units=input_size, ...)
...
我想我在做什么:让神经网络找到一个“均值”和“标准差”值,并用它创建正态分布(高斯)。 从该分布中采样并用于解码器。 换句话说:学习编码的高斯分布
现在我想对混合高斯人做同样的事情。
...
net = tf.layers.dense(net, units=code_size * 2 * code_size, activation=None)
means, stds = tf.split(net, 2, axis=-1)
means = tf.split(means, code_size, axis=-1)
stds = tf.split(stds, code_size, axis=-1)
components = [tfd.MultivariateNormalDiagWithSoftplusScale(means[i], stds[i]) for i in range(code_size)]
probs = [1.0 / code_size] * code_size
gauss_mix = tfd.Mixture(cat=tfd.Categorical(probs=probs), components=components)
net = gauss_mix.sample()
net = tf.layers.dense(net, units=input_size, ...)
...
对于我来说,这似乎比较简单,但失败并出现以下错误:
形状()和(?,)不兼容
这似乎来自probs
,它没有批处理维度(我不认为它需要这样做)。
我认为probs
定义了组件之间的概率。
如果我定义的probs
也具有批处理维度,则会出现以下神秘错误,我不知道它的含义:
维度-1796453376必须为> = 0
我通常会误解一些概念吗?
或者我需要做些什么?