TensorFlow概率中批量混合物分布的概率

时间:2019-10-07 09:28:42

标签: tensorflow tensorflow-probability

TFP发行版应具有开箱即用的批处理能力。但是,我正面临批量混合物分配的问题。 这里是一个玩具示例(使用急切执行):

tfd = tfp.distributions
mix = np.array([[0.6, 0.4],[0.3, 0.7]] )
bimix_gauss = tfd.Mixture(
  cat=tfd.Categorical(probs=mix),
  components=[
    tfd.Normal(loc=[-1.0, -2.0], scale=[0.1, 0.1]),
    tfd.Normal(loc=[+1.0, +2.0], scale=[0.5, 0.5]),
])

print(bimix_gauss.sample())
print(bimix_gauss.prob(0.0))

基本上,这只是默认示例的基础:https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/Mixture

采样工作正常,但是此分布的可归类性返回错误: InvalidArgumentError: cannot compute Add as input #1(zero-based) was expected to be a double tensor but is a float tensor [Op:Add] name: Mixture/prob/add/

任何猜测,我在做什么错了?

PS。具有批处理高斯分布的同一示例工作正常。

1 个答案:

答案 0 :(得分:0)

问题是numpy默认为float64,但TFP遵循默认为float32的TF约定。因此,您的正态分布(其参数为纯python列表)将重铸为tf.Normal构造函数中的张量以float32张量最终导致类型错误。您可以通过将np数组强制为float 32来解决问题,也可以通过将混合值作为列表而不是ndarrays传递而更简单。