来自tfd.Distribution.sample的复杂event_shape

时间:2019-05-22 08:03:16

标签: python tensorflow tensorflow-probability

我正在尝试实现代表复杂状态转换模型(STM)的自定义tfd.Distribution。我需要从抽象方法tfd.Distribution._sample_n的实现中返回2个具有不同维度的数组的元组。但是,当包装方法(tfd.Distribution.sample)尝试打包这些数组时,我遇到了麻烦。

STM表征存在于许多互斥状态中的种群。随着时间的流逝,人口中的个体会根据随机过程在各州之间转移。为了表示STM的实现(即样本),最终得到一个长度为T的向量,其中包含发生过渡的时间,以及一个形状为[T,M,N]的多维数组,其中T为时间步长,M为州数,N为人口总数。

到目前为止,我有:

class Foo(tfd.Distribution):
    def __init__(self):
        super().__init__(dtype=tf.float32,
                         #...other config here
                        )

    def _sample_n(self, n, seed=None):
        # Sampling algorithm here
        # t.shape = [T]
        # y.shape = [T, M, N]
        return t, y

foo = Foo()
foo.sample()

期望的结果:调用foo.sample()应该返回一个(tf.tensor,tf.tensor)元组,其形状分别为[T]和[T,M,N]。

实际:

ValueError: Shapes must be equal rank, but are 1 and 3
    From merging shape 0 with other shapes. for 'MyEpidemic/sample/Shape/packed' (op: 'Pack') with input shapes: [11], [11,3,1000].

1 个答案:

答案 0 :(得分:1)

看看JointDistribution*here;他们可以解决这个问题。

请注意,它们仅在pip install tfp-nightly中。

tfp.distributions.JointDistribution
tfp.distributions.JointDistributionCoroutine
tfp.distributions.JointDistributionCoroutine.Root
tfp.distributions.JointDistributionNamed
tfp.distributions.JointDistributionSequential