批量生成协方差矩阵(TensorFlow)

时间:2017-10-18 16:04:34

标签: tensorflow

我想知道如何在TensorFlow中批量生成协方差矩阵。如果我们使用以下代码

dim = 3
df = 5
ds = tf.contrib.distributions
scale_sqrt = tf.random_normal([dim, dim], seed=1234)
scale = tf.matmul(scale_sqrt, tf.transpose(scale_sqrt))
sigma = ds.WishartCholesky(df=df, scale=scale).sample()

它会起作用。但是,如果我们通过添加额外的批量维度来尝试此代码的批处理版本,那么TF将引发错误。我的批处理版本如下所示:

dim = 3
df = 5
ds = tf.contrib.distributions
num_per_batch = 10
scale_sqrt = tf.random_normal([num_per_batch, dim, dim], seed=1234)
scale = tf.matmul(scale_sqrt, tf.transpose(scale_sqrt, [0,2,1]))
sigma = ds.WishartCholesky(df=df, scale=scale).sample()

请让我知道如何有效地批量采样。

1 个答案:

答案 0 :(得分:0)

对于后验性,这个bug已被修复,并且应该在tf-nightly版本中可用(当我在sigma上调用sess.run(...)时,我会返回没有错误的值。)