How can I sample consistently from several Mixture distributions?

Time: 2018-02-20 10:19:23

Tags: python tensorflow

I would like to get consistent samples from several Mixture distributions. For example, with this code:

import tensorflow as tf
from tensorflow.contrib.distributions import Mixture, Normal, Deterministic, Categorical
import numpy as np

rate = 0.5 

cat = Categorical(probs=[1-rate, rate])

f1 = Mixture(cat=cat, components=[Normal(loc=10., scale=1.), Deterministic(0.)])
f2 = Mixture(cat=cat, components=[Normal(loc=5., scale=1.), Deterministic(0.)])

sess = tf.Session()
tf.global_variables_initializer().run(session=sess)

sess.run([cat.sample(), f1.sample(), f2.sample()])

I get:

[1, 10.4463625, 0.0]

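This is not what I want: the three values are drawn independently (here the Categorical draw is 1, yet f1 returns a sample from its Normal component), which makes sense once you look at the source code of the .sample() method.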

My question: how can I draw samples so that the Categorical variable is evaluated first and its value is shared between f1 and f2?
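To make the intended semantics concrete, here is a TensorFlow-free sketch in plain Python (it does not use or imitate the TensorFlow API). Each mixture is modeled as a hypothetical list of zero-argument sampler functions; sample_independent mimics calling .sample() on each Mixture separately, while sample_shared draws the Categorical index once and reuses it for every mixture. All function names are illustrative assumptions.

```python
import random

def normal(loc, scale, rng):
    """Hypothetical sampler for a Normal(loc, scale) component."""
    return lambda: rng.gauss(loc, scale)

def deterministic(value):
    """Hypothetical sampler for a Deterministic(value) component."""
    return lambda: value

def sample_categorical(probs, rng):
    """Draw one index i with probability probs[i]."""
    return rng.choices(range(len(probs)), weights=probs)[0]

def sample_independent(probs, mixtures, rng):
    """Each mixture draws its own index, like separate Mixture.sample() calls."""
    return [components[sample_categorical(probs, rng)]() for components in mixtures]

def sample_shared(probs, mixtures, rng):
    """Draw the index once, then pick the same component in every mixture."""
    idx = sample_categorical(probs, rng)
    return idx, [components[idx]() for components in mixtures]

rate = 0.5
probs = [1 - rate, rate]
rng = random.Random(0)
f1 = [normal(10.0, 1.0, rng), deterministic(0.0)]
f2 = [normal(5.0, 1.0, rng), deterministic(0.0)]

idx, (x1, x2) = sample_shared(probs, [f1, f2], rng)
print(idx, x1, x2)  # x1 and x2 both come from component number idx
```

With sample_shared, an index of 1 makes both values 0.0 and an index of 0 makes both come from their Normal components, which is the behavior the question asks for.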

1 answer:

Answer 0 (score: 0)

There is no clean way to do this with the existing library code (although there are internal bug reports open about it).

For now, you can create a hacked Categorical distribution that always returns the same cached sample tensor:

import tensorflow as tf
from tensorflow.contrib.distributions import Mixture, Normal, Deterministic, Categorical
import numpy as np

class HackedCat(tf.distributions.Categorical):
  """Categorical whose sample() returns one cached sample tensor by default."""

  def __init__(self, *args, **kwargs):
    super(HackedCat, self).__init__(*args, **kwargs)
    # Build the sample op once; every consumer (e.g. each Mixture) gets this tensor.
    self._cached_sample = self.sample(use_cached=False)

  def sample(self, *args, **kwargs):
    # Use the cached sample by default or when explicitly asked to.
    # Note: sample_shape and seed arguments are ignored in the cached case.
    if kwargs.pop('use_cached', True):
      return self._cached_sample
    # use_cached=False: build a fresh, independent sample op.
    return super(HackedCat, self).sample(*args, **kwargs)

def main():
  rate = 0.5 
  cat = HackedCat(probs=[1-rate, rate])
  f1 = Mixture(cat=cat,
               components=[Normal(loc=10., scale=1.),
                           Deterministic(0.)])
  f2 = Mixture(cat=cat,
               components=[Normal(loc=5., scale=1.),
                           Deterministic(0.)])

  with tf.Session() as sess:
    tf.global_variables_initializer().run(session=sess)
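    # The cached tensor is evaluated once per run, so all three values
    # agree on the component index.
    print(sess.run([cat.sample(), f1.sample(), f2.sample()]))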

if __name__ == '__main__':
  main()
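The caching idea itself does not depend on TensorFlow. The plain-Python sketch below (the classes CachedCategorical and SimpleMixture are hypothetical, not part of any library) mimics it: each mixture asks the shared categorical for an index, the categorical hands back a cached draw, and refresh() stands in for starting a new sess.run, where the cached TF tensor would be re-evaluated.

```python
import random

class CachedCategorical:
    """Hypothetical categorical that returns one cached draw until refresh()."""

    def __init__(self, probs, rng):
        self.probs = list(probs)
        self.rng = rng
        self.refresh()

    def refresh(self):
        # Stand-in for a new sess.run: draw a fresh index and cache it.
        self._cached = self.sample(use_cached=False)

    def sample(self, use_cached=True):
        if use_cached:
            return self._cached
        return self.rng.choices(range(len(self.probs)), weights=self.probs)[0]


class SimpleMixture:
    """Hypothetical minimal mixture: ask `cat` for an index, sample that component."""

    def __init__(self, cat, components):
        self.cat = cat
        self.components = components  # list of zero-argument samplers

    def sample(self):
        return self.components[self.cat.sample()]()


rng = random.Random(0)
cat = CachedCategorical([0.5, 0.5], rng)
f1 = SimpleMixture(cat, [lambda: rng.gauss(10.0, 1.0), lambda: 0.0])
f2 = SimpleMixture(cat, [lambda: rng.gauss(5.0, 1.0), lambda: 0.0])

print(cat.sample(), f1.sample(), f2.sample())  # all three use the same index
cat.refresh()  # the next "run" gets a new shared index
print(cat.sample(), f1.sample(), f2.sample())
```

One property carries over to the TF version: the cached path ignores any sample_shape or seed argument, so, as written, the trick appears aimed at drawing a single shared sample per run.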