我想知道如何在TensorFlow中批量生成协方差矩阵。如果我们使用以下代码
dim = 3
df = 5
ds = tf.contrib.distributions
scale_sqrt = tf.random_normal([dim, dim], seed=1234)
scale = tf.matmul(scale_sqrt, tf.transpose(scale_sqrt))
sigma = ds.WishartCholesky(df=df, scale=scale).sample()
它会起作用。但是,如果我们通过添加额外的批量维度来尝试此代码的批处理版本,那么TF将引发错误。我的批处理版本如下所示:
dim = 3
df = 5
ds = tf.contrib.distributions
num_per_batch = 10
scale_sqrt = tf.random_normal([num_per_batch, dim, dim], seed=1234)
scale = tf.matmul(scale_sqrt, tf.transpose(scale_sqrt, [0,2,1]))
sigma = ds.WishartCholesky(df=df, scale=scale).sample()
请让我知道如何有效地批量采样。
答案 0 :(得分:0)
对于后验性,这个bug已被修复,并且应该在tf-nightly版本中可用(当我在sigma上调用sess.run(...)时,我会返回没有错误的值。)