我希望能够从多个Mixture分布中获得一致的样本。即,例如,我的代码:
import tensorflow as tf
from tensorflow.contrib.distributions import Mixture, Normal, Deterministic, Categorical
import numpy as np
rate = 0.5
cat = Categorical(probs=[1-rate, rate])
f1 = Mixture(cat=cat, components=[Normal(loc=10., scale=1.), Deterministic(0.)])
f2 = Mixture(cat=cat, components=[Normal(loc=5., scale=1.), Deterministic(0.)])
sess = tf.Session()
tf.global_variables_initializer().run(session=sess)
sess.run([cat.sample(), f1.sample(), f2.sample()])
我明白了:
[1, 10.4463625, 0.0]
这不是我想要的,因为它们是独立绘制的,如果我们看一下.sample()
方法的源代码就有意义了。
我的问题:如何绘制样本,以便首先评估Categorical
变量,并在f1
和f2
之间共享?
答案 0 :(得分:0)
现有的库代码没有非常简洁的方法来实现这一点(尽管这些内容存在内部错误)。
现在,您可以创建一个虚拟分类分布,返回相同的缓存采样张量:
import tensorflow as tf
from tensorflow.contrib.distributions import Mixture, Normal, Deterministic, Categorical
import numpy as np
class HackedCat(tf.distributions.Categorical):
def __init__(self, *args, **kwargs):
super(HackedCat, self).__init__(*args, **kwargs)
self._cached_sample = self.sample(use_cached=False)
def sample(self, *args, **kwargs):
# Use cached sample by default or when explicitly asked to
if 'use_cached' not in kwargs or kwargs['use_cached']:
return self._cached_sample
else:
if 'use_cached' in kwargs:
del kwargs['use_cached']
return super(HackedCat, self).sample(*args, **kwargs)
def main():
rate = 0.5
cat = HackedCat(probs=[1-rate, rate])
f1 = Mixture(cat=cat,
components=[Normal(loc=10., scale=1.),
Deterministic(0.)])
f2 = Mixture(cat=cat,
components=[Normal(loc=5., scale=1.),
Deterministic(0.)])
with tf.Session() as sess:
tf.global_variables_initializer().run(session=sess)
print sess.run([cat.sample(), f1.sample(), f2.sample()])
if __name__ == '__main__':
main()