尝试使用PYMC3和NetworkX进行极其基本的(3个分类变量)贝叶斯推理

时间:2018-04-26 04:47:14

标签: python networkx bayesian pymc3 bayesian-networks

我试图理解贝叶斯网络的this example。我把它弄得更笨,以至于它只看三个变量:D1,D2和D3。每个都是分类的,其概率表在下面的代码顶部给出。我想设置D3 = 0,然后计算D1和D2的后验概率,就像在this page底部完成的更简单的版本一样。我尝试通过使用第一个来源的代码来尝试这样做,但是没有成功,我不理解错误消息。

对此的任何帮助都将非常感激 - 我一直在努力实施贝叶斯推理。我试过看PYMC3 Categorical documentation,但它非常简陋。并且example of inference I could find使用连续变量,并且似乎做了与我尝试做的不同的事情。或者如果它不是,我就不够聪明地建立联系并使用他们为了满足我的需要而展示的任何东西。

我不确定是否批准发布大部分代码?但我不知道怎么做才能做到这一点。这是我的代码(第一个源代码中更短,更简单的代码版本):

import networkx as nx
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pymc3 as pm
import theano
import theano.tensor as T
from theano.compile.ops import as_op

d1_prob = np.array([0.3,0.7])  # 2 choices
d2_prob = np.array([0.6,0.3,0.1])  # 3 choices
d3_prob = np.array([[[0.1, 0.9],  # (2x3)x2 choices
                 [0.3, 0.7], 
                 [0.4, 0.6]], 
                [[0.6, 0.4], 
                 [0.8, 0.2],
                 [0.9, 0.1]]])

BN = nx.DiGraph()
BN.add_node('D1', dtype='Discrete', prob=d1_prob)
BN.add_node('D2', dtype='Discrete', prob=d2_prob)
BN.add_node('D3', dtype='Discrete', prob = d3_prob, observe=np.array([0.]))
BN.add_edges_from([('D1', 'D3'), ('D2', 'D3')])

#print(BN.nodes(data=True))
#print(BN.pred['D3'])

def gpm(BN, node, num=0):
    return BN.node[BN.predecessors(node)[num]]['dist_obj']

with pm.Model() as mod2:

BN.node['D1']['dist_obj'] = pm.Categorical('D1', p=BN.node['D1']['prob'])
BN.node['D2']['dist_obj'] = pm.Categorical('D2', p=BN.node['D2']['prob'])
BN.node['D3']['dist_obj'] = pm.Categorical('D3', p=BN.node['D3']['prob'][
    gpm(BN,'D3', num=1),
    gpm(BN,'D3', num=0)
], observed=BN.node['D3']['observe'])

with mod2:
trace = pm.sample(10000)

pm.summary(trace, varnames=['D3'], start=1000)
pm.traceplot(trace[1000:], varnames=['D3'])

1 个答案:

答案 0 :(得分:0)

抱歉,我对PyMC3无法帮助你。但也许你只需要数字。

实际上我不明白为什么你需要一个推理算法。

概率表已完全指定,没有丢失的数据,因此您可以在此处应用贝叶斯规则。不可否认,即使是这么简单的例子,我也不想用铅笔和纸做这件事。所以我在这里使用了基于java的GUI工具samiam来为我使用贝叶斯规则。

什么都没有观察到: enter image description here

解释您的代码jobAdress.getAddressLine(0) gpm(),您会发现d3 = 1.然后CPT值会更改为:

enter image description here

(state0值是任意的,samiam只分配默认标签stateX)。 CPT中的行位置非常重要。