参数化张量流中的混合密度网络协方差

时间:2018-04-13 06:13:56

标签: python tensorflow mixture-model

我正在尝试构建MDN以学习P(y | x),其中y和x都具有维D,其中K分量具有完全(非对角)协方差。从NN的隐藏层的输出我需要构建组件均值,权重和协方差。对于协方差,我想要一组下三角矩阵(即协方差的Cholesky因子),即[K,D,D]张量,所以我可以利用这样的事实:对于正定矩阵,你只需随身携带矩阵的一个三角形。

目前,参数化均值(locs),权重(日志)和协方差(尺度)的NN如下所示:

def neural_network(X):

  # 2 hidden layers with 15 hidden units
  net = tf.layers.dense(X, 15, activation=tf.nn.relu)
  net = tf.layers.dense(net, 15, activation=tf.nn.relu)
  locs = tf.reshape(tf.layers.dense(net, K*D, activation=None), shape=(K, D))
  logits = tf.layers.dense(net, K, activation=None)
  scales = # some function of tf.layers.dense(net, K*D*(D+1)/2, activation=None) ?

  return locs, scales, logits

问题是,对于尺度来说,将tf.layers.dense(net, K*D*(D-1)/2, activation=None)转换为K DxD下三角矩阵的张量的最有效方法是什么(对角元素取幂以确保正定性)?

1 个答案:

答案 0 :(得分:1)

TL; DR:使用tf.contrib.distributions.fill_triangular


假设X是K维度的D个元素的张量,我们将其定义为placeholder

# batch of D-dimensional inputs
X = tf.placeholder(tf.float64, [None, D])

神经网络的定义与OP一样。

# 2 hidden layers with 15 hidden units
net = tf.layers.dense(X, 15, activation=tf.nn.relu)
net = tf.layers.dense(net, 15, activation=tf.nn.relu)

多元高斯的均值只是先前隐藏层的线性密集层。输出的形状为(None, D),因此无需将尺寸乘以K并重塑形状。

# Parametrisation of the means
locs = tf.layers.dense(net, D, activation=None)

接下来,我们定义下三角协方差矩阵。关键是在另一个线性密集层的输出上使用tf.contrib.distributions.fill_triangular

# Parametrisation of the lower-triangular covariance matrix
covariance_weights = tf.layers.dense(net, D*(D+1)/2, activation=None)
lower_triangle = tf.contrib.distributions.fill_triangular(covariance_weights)

最后一件事:我们需要确保协方差矩阵是正半定的。通过将softplus激活函数应用于对角线元素,即可轻松实现。

# Diagonal elements must be positive
diag = tf.matrix_diag_part(lower_triangle)
diag_positive = tf.layers.dense(diag, D, activation=tf.nn.softplus)
covariance_matrix = lower_triangle - tf.matrix_diag(diag) + tf.matrix_diag(diag_positive)

就是这样,我们已经使用神经网络对多元正态分布进行了参数化。


奖金:可训练的多元正态分布

Tensorflow Probability软件包具有可训练的多元正态分布,并且具有较低的三角协方差矩阵:tfp.trainable_distributions.multivariate_normal_tril

它可以如下使用:

mvn = tfp.trainable_distributions.multivariate_normal_tril(net, D)

它使用与tfp.distributions.MultivariateNormalTriL相同的方法(包括meancovariancesample等)输出多元正态三角分布。

我建议您使用它而不是自己构建。