多元高斯分布张量流概率的混合

时间:2019-12-13 00:23:05

标签: python tensorflow2.0 mixture-model tensorflow-probability

正如标题中所述,我正在尝试使用张量流概率包来创建多元正态分布的混合。

在我的原始项目中,我从神经网络的输出中获取类别,位置和方差的权重。但是,在创建图形时,出现以下错误:

  

components [0]批处理形状必须与cat形状和其他组件批处理形状兼容

我使用占位符重新创建了相同的问题:

import tensorflow as tf
import tensorflow_probability as tfp # dist= tfp.distributions 

tf.compat.v1.disable_eager_execution()
sess = tf.compat.v1.InteractiveSession()

l1 = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 2], name='observations_1')
l2 = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None, 2], name='observations_2')

log_std = tf.compat.v1.get_variable('log_std', [1, 2], dtype=tf.float32,
                                          initializer=tf.constant_initializer(1.0),
                                          trainable=True)

mix = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None,1], name='weights')

cat = tfp.distributions.Categorical(probs=[mix, 1.-mix])
components = [
    tfp.distributions.MultivariateNormalDiag(loc=l1, scale_diag=tf.exp(log_std)),
    tfp.distributions.MultivariateNormalDiag(loc=l2, scale_diag=tf.exp(log_std)),
]

bimix_gauss = tfp.distributions.Mixture(
  cat=cat,
  components=components)

所以,我的问题是,我做错了什么?我调查了错误,似乎是tensorshape_util.is_compatible_with引发了错误,但我不明白为什么。

谢谢!

2 个答案:

答案 0 :(得分:0)

似乎您为tfp.distributions.Categorical提供了错误的输入。参数probs的形状应为[batch_size, cat_size],而您提供的参数应为[cat_size, batch_size, 1]。因此,也许尝试使用probs参数化tf.concat([mix, 1-mix], 1)

您的log_std可能也有问题,其形状与l1l2不同。如果MultivariateNormalDiag不能正确广播,请尝试将其形状指定为(None, 2)或对其进行平铺,以使其第一维与您的位置参数相对应。

答案 1 :(得分:0)

当组件是相同类型时,MixtureSameFamily应该性能更高。

仅传递一个分类实例(带有.batch_shape [b1,b2,...,bn])和单个MVNDiag实例(带有.batch_shape [b1,b2,...,bn,numcats])

对于只有两节课,我想知道伯努利是否会上课?