PyMC3 multinomial model doesn't work with non-integer observed data

Time: 2016-09-01 17:21:21

Tags: python pymc pymc3

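I'm trying to use PyMC3 to fit a fairly simple multinomial distribution. If I set the 'noise' value to 0.0, it works perfectly. However, when I change it to anything else, e.g. 0.01, I get an error in the find_MAP() function, and if I don't use find_MAP(), sampling hangs.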

Is there a reason why the multinomial has to be sparse?
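Some background on the question: a multinomial probability mass function is only defined for count vectors of non-negative whole numbers. Adding noise and then normalising each row, as the code below does, produces fractional values that fall outside that support, so a strict log-pmf returns -inf, which matches the `logp: array(-inf)` in the error. The sketch below is a minimal pure-Python illustration under the assumption that n equals the sum of the counts; `multinomial_logpmf` is a hypothetical helper, and PyMC3's own `Multinomial` may check its support differently.

```python
import math


def multinomial_logpmf(x, p):
    """Hypothetical helper: log-pmf of a multinomial with counts x, probs p.

    Assumes n = sum(x). Returns -inf when x is outside the support,
    i.e. when any entry is negative or not a whole number.
    """
    if any(xi < 0 or xi != int(xi) for xi in x):
        return -math.inf
    n = int(sum(x))
    # log of the multinomial coefficient n! / (x_1! ... x_k!)
    log_coef = math.lgamma(n + 1) - sum(math.lgamma(xi + 1) for xi in x)
    # log of prod p_i ** x_i, skipping zero counts
    log_prob = sum(xi * math.log(pi) for xi, pi in zip(x, p) if xi > 0)
    return log_coef + log_prob


true_probs = [0.2, 0.1, 0.3, 0.4]

# One experiment with integer counts (noise = 0.0): a valid observation.
print(multinomial_logpmf([0, 0, 1, 0], true_probs))  # equals log(0.3)

# The same row after adding noise = 0.01 and normalising: no longer integers.
row = [0.01, 0.01, 1.01, 0.01]
total = sum(row)
noisy = [v / total for v in row]
print(multinomial_logpmf(noisy, true_probs))  # -inf
```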

import numpy as np
from pymc3 import *
import pymc3 as mc
import pandas as pd
print('pymc3 version: ' + mc.__version__)


sample_size = 10
number_of_experiments = 1


true_probs = [0.2, 0.1, 0.3, 0.4]


k = len(true_probs)


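noise = 0.0  # 0.0 works; a non-zero value such as 0.01 triggers the error below
y = np.random.multinomial(n=number_of_experiments, pvals=true_probs, size=sample_size)+noise
y_denominator = np.sum(y,axis=1)
y = y/y_denominator[:,None]  # normalise each row so it sums to 1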


with Model() as multinom_test:
    probs = Dirichlet('probs', a = np.ones(k), shape = k)
    for i in range(sample_size):
        data = Multinomial('data_%d' % (i),
                           n = y[i].sum(),
                           p = probs,
                           observed = y[i])


with multinom_test:
    start = find_MAP()
    trace = sample(5000, Slice(), start=start)
trace['probs'].mean(0)

Error:

ValueError: Optimization error: max, logp or dlogp at max have non-finite values. Some values may be outside of distribution support.
max: {'probs_stickbreaking_': array([  0.00000000e+00,  -4.47034834e-08,   0.00000000e+00])}
logp: array(-inf)
dlogp: array([  0.00000000e+00,   2.98023221e-08,   0.00000000e+00])
Check that 1) you don't have hierarchical parameters, these will lead to points with infinite density. 2) your distribution logp's are properly specified. Specific issues:
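Since the error says some values may be outside the distribution support, one way to check is to inspect the observed rows before building the model. The sketch below reproduces the question's data-preparation step with NumPy and tests each row against the multinomial support: non-negative whole numbers summing to the `n` passed to `Multinomial` (here 1, because each row is normalised). The helper `rows_are_valid_counts` and the random seed are assumptions added for illustration.

```python
import numpy as np


def rows_are_valid_counts(y, n):
    """Hypothetical helper: True for each row that is a non-negative
    integer vector summing to n."""
    y = np.asarray(y, dtype=float)
    whole = np.all(np.equal(np.round(y), y), axis=1)
    nonneg = np.all(y >= 0, axis=1)
    sums_ok = np.isclose(y.sum(axis=1), n)
    return whole & nonneg & sums_ok


rng = np.random.default_rng(0)  # seed chosen only for reproducibility
true_probs = [0.2, 0.1, 0.3, 0.4]
counts = rng.multinomial(1, true_probs, size=10)  # 10 rows, 1 experiment each

results = {}
for noise in (0.0, 0.01):
    y = counts + noise
    y = y / y.sum(axis=1)[:, None]  # the question's normalisation step
    results[noise] = bool(rows_are_valid_counts(y, n=1).all())

print(results)  # {0.0: True, 0.01: False}
```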

1 answer:

Answer 0 (score: 4)

This works for me:

import numpy as np
import pymc3 as pm

sample_size = 10
number_of_experiments = 100

true_probs = [0.2, 0.1, 0.3, 0.4]
k = len(true_probs)
noise = 0.01
y = np.random.multinomial(n=number_of_experiments, pvals=true_probs, size=sample_size)+noise

with pm.Model() as multinom_test:
    a = pm.Dirichlet('a', a=np.ones(k))
    for i in range(sample_size):
        data_pred = pm.Multinomial('data_pred_%s'% i, n=number_of_experiments, p=a, observed=y[i])
    trace = pm.sample(50000, pm.Metropolis())
    #trace = pm.sample(1000) # also works with NUTS

pm.traceplot(trace[500:]);

(Image: traceplot output)
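Because a Dirichlet prior is conjugate to a multinomial likelihood, the posterior of the answer's model can also be computed in closed form, which gives a reference to compare the sampled trace against. With a Dirichlet(α) prior and integer count vectors y_1 to y_m, the posterior is Dirichlet(α + Σ y_i), whose mean is (α + Σ y_i) divided by the sum of its entries. The sketch below assumes integer counts (noise = 0), the flat prior `np.ones(k)` used in the answer, and an arbitrary seed; `posterior_mean` is a hypothetical helper.

```python
import numpy as np


def posterior_mean(alpha, counts):
    """Hypothetical helper: posterior mean of p for a Dirichlet(alpha) prior
    and multinomial count rows.

    counts has shape (num_rows, k); the conjugate posterior is
    Dirichlet(alpha + counts.sum(axis=0)).
    """
    post = np.asarray(alpha, dtype=float) + np.asarray(counts).sum(axis=0)
    return post / post.sum()


rng = np.random.default_rng(42)  # arbitrary seed
true_probs = [0.2, 0.1, 0.3, 0.4]
k = len(true_probs)
# As in the answer: 10 rows of 100 experiments each, but with no noise added.
counts = rng.multinomial(100, true_probs, size=10)

print(posterior_mean(np.ones(k), counts))
```

With 1000 total trials this vector should sit close to `true_probs`, and comparing it with `trace['a'].mean(0)` from the answer's model is a quick sanity check on the sampler.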