我正在尝试实现代表复杂状态转换模型(STM)的自定义tfd.Distribution。我需要从抽象方法tfd.Distribution._sample_n的实现中返回2个具有不同维度的数组的元组。但是,当包装方法(tfd.Distribution.sample)尝试打包这些数组时,我遇到了麻烦。
STM表征存在于许多互斥状态中的种群。随着时间的流逝,人口中的个体会根据随机过程在各州之间转移。为了表示STM的实现(即样本),最终得到一个长度为T的向量,其中包含发生过渡的时间,以及一个形状为[T,M,N]的多维数组,其中T为时间步长,M为州数,N为人口总数。
到目前为止,我有:
class Foo(tfd.Distribution):
def __init__(self):
super().__init__(dtype=tf.float32,
#...other config here
)
def _sample_n(self, n, seed=None):
# Sampling algorithm here
# t.shape = [T]
# y.shape = [T, M, N]
return t, y
foo = Foo()
foo.sample()
期望的结果:调用foo.sample()
应该返回一个(tf.tensor,tf.tensor)元组,其形状分别为[T]和[T,M,N]。
实际:
ValueError: Shapes must be equal rank, but are 1 and 3
From merging shape 0 with other shapes. for 'MyEpidemic/sample/Shape/packed' (op: 'Pack') with input shapes: [11], [11,3,1000].
答案 0 :(得分:1)
看看JointDistribution*
类here;他们可以解决这个问题。
请注意,它们仅在pip install tfp-nightly
中。
tfp.distributions.JointDistribution
tfp.distributions.JointDistributionCoroutine
tfp.distributions.JointDistributionCoroutine.Root
tfp.distributions.JointDistributionNamed
tfp.distributions.JointDistributionSequential