如何可视化tf.contrib.distribution.MultivariateNormalDiag?

时间:2018-05-15 21:10:38

标签: tensorflow

我构建了一个变分自动编码器,它具有潜在分布作为多元正态分布。我想存储分发并将其可视化(在tensorboard或其他软件中)。但是,当我尝试使用add_summary中的FileWriter时,我收到错误 - MultivariateNormalDiag object has no attribute 'value'。如何存储和显示高斯分布? 代码:

import tensorflow as tf
tfd = tf.contrib.distribution

input = tf.placeholder(tf.float32, [None, 9])
x = tf.layers.dense(input,200, tf.nn.relu)
x = tf.layers.dense(x,200,tf.nn.relu)
loc = tf.layers.dense(x,1)
scale = tf.layers.dense(x,1,tf.nn.softplus)
latent = tfd.MultivariateNormalDiag(loc,scale)
# getting loss and minimizing it
writer = tf.summary.FileWriter("/tmp/dist")
writer.add_summary(latent) # Error over here

1 个答案:

答案 0 :(得分:0)

您应该将tf.Summary传递给writer.add_summary()。 相反,您传递的是tfd.MultivariateNormalDiag,其中没有value字段。

tfd.MultivariateNormalDiag中提取您要写入的值,并使用该值创建new tf.Summary。然后将其传递给writer.add_summary()

详细了解here