计算概率树中的概率

时间:2016-01-06 00:47:55

标签: python-3.x probability

我有以下字典:

# Each key is a node of the probability tree; its value is a list of
# [probability, numbers] branches.  NOTE: the name shadows the builtin `dict`.
dict = {
    1000021: [[0.6, [1000024, 1, -2]], [0.4, [1000022, 21]]],
    1000024: [[0.7, [1000022, 11, -12]], [0.3, [1000022, 2, -1]]],
}

对应于以下概率树:

enter image description here

从 1000021 开始,我需要计算每个可能端点的概率及其对应的数字列表。每当遇到一个在字典中有条目的数字时,就需要沿着该条目继续往下走。字典可以包含任意数量的条目,每个条目也可以包含任意数量的子列表。期望的输出:

[0.4, [1000022, 21]],
[0.42, [1000022, 11, -12, 1, -2]],
[0.18, [1000022, 2, -1, 1, -2]]

我尝试使用递归函数执行此操作,但无济于事。任何帮助表示赞赏。

编辑:

我在第一个例子中表述得不够清楚:它可能让人误以为只有子列表中的第一个元素才会有字典条目,而实际上子列表中的任何元素都可能有。Copperfield 给出的答案适用于上面的例子,但不适用于下面这样的数据:

# Here more than one element of a branch (4 and 5) is itself a key.
mydata = {
    1: [[0.9, [2, 3]], [0.1, [4, 5]]],
    4: [[0.2, [6, 7]], [0.5, [8, 9]], [0.3, [10, 11, 12]]],
    5: [[0.4, [13, 14]], [0.6, [15, 16]]],
}

我期望输出为:

[0.9, [2, 3]],
[0.008, [6, 7, 13, 14]],
[0.012, [6, 7, 15, 16]],
[0.02, [8, 9, 13, 14]],
[0.03, [8, 9, 15, 16]],
[0.012, [10, 11, 12, 13, 14]],
[0.018, [10, 11, 12, 15, 16]]

1 个答案:

答案 0 :(得分:1)

这算不上真正的树结构,不过这样做如何:

import copy

# node -> list of [probability, path] branches; the walk starts at 1000021.
mydata = {
    1000024: [[0.7, [1000022, 11, -12]], [0.3, [1000022, 2, -1]]],
    1000021: [[0.6, [1000024, 1, -2]], [0.4, [1000022, 21]]],
}

def prob_tree(data,ini,prob=1):
    """Print every end point of the probability tree rooted at ``ini``.

    ``data`` maps a node to a list of ``[probability, path]`` branches.
    Only the FIRST element of a path is followed: when it is a key of
    ``data`` the walk descends into that node, and the remaining elements
    of the path are appended to every end point found below it.  A node
    that is already on the current chain of ancestors is not expanded
    again, so a circular reference is printed as a plain end point.

    One line ``<probability> <path>`` is printed per end point, where the
    probability is ``prob`` multiplied by every branch probability on the
    way down.  ``data`` is never modified.

    Args:
        data: mapping ``node -> [[probability, [n0, n1, ...]], ...]``.
        ini: node to start from; nothing is printed if it is not a key.
        prob: probability of reaching ``ini`` (defaults to 1).
    """
    def walk(node, reach_prob, suffix, ancestors):
        # Ancestors are tracked per chain instead of deep-copying and
        # popping from the dict at every level.
        ancestors = ancestors | {node}
        for branch_prob, path in data.get(node) or ():
            # Carry the whole accumulated product down, not just the
            # probability of the last branch taken.
            total = reach_prob * branch_prob
            if path and path[0] in data and path[0] not in ancestors:
                # Build a fresh suffix for this branch so that siblings
                # pointing at the same child cannot contaminate each other.
                walk(path[0], total, list(path[1:]) + suffix, ancestors)
            else:
                print(total, list(path) + suffix)

    walk(ini, prob, [], frozenset())

# Walk the tree from node 1000021; each end point is printed.
prob_tree(mydata, 1000021)

输出

0.42 [1000022, 11, -12, 1, -2]
0.18 [1000022, 2, -1, 1, -2]
0.4 [1000022, 21]

修改

灵光一现,这里给出一个采用函数式风格的新版本:

import itertools, functools

def partition(pred, iterable):
    """Split ``iterable`` into two lazy iterators using ``pred``.

    Returns a pair ``(rejected, accepted)``: the first iterator yields the
    entries for which ``pred`` is false, the second the entries for which
    it is true.

    Example: partition(is_odd, range(10)) --> 0 2 4 6 8   and  1 3 5 7 9
    """
    # Taken from the recipes section of the itertools documentation.
    # tee() lets both result iterators walk the same source independently.
    rejected_source, accepted_source = itertools.tee(iterable)
    rejected = itertools.filterfalse(pred, rejected_source)
    accepted = filter(pred, accepted_source)
    return rejected, accepted


def prob_tree(data,ini) -> (float,tuple):
    """Yield ``(probability, path)`` for every end point reachable from ``ini``.

    ``data`` maps a node to a list of ``[probability, path]`` branches.
    Every element of a path that is itself a key of ``data`` is expanded
    recursively; the other elements are kept as they are.  Each yielded
    path starts with the kept elements, followed by the end points of the
    expanded elements, and its probability is the product of all branch
    probabilities involved.
    """
    for branch_prob, branch_path in data[ini]:
        terminal, expandable = map(tuple, partition(lambda item: item in data, branch_path))
        if not expandable:
            yield branch_prob, terminal
            continue
        # One sub-generator per expandable element; product() enumerates
        # every combination of their end points.
        sub_trees = [prob_tree(data, key) for key in expandable]
        for combination in itertools.product(*sub_trees):
            total = branch_prob
            tail = tuple()
            for sub_prob, sub_path in combination:
                total = total * sub_prob
                tail = tail + sub_path
            yield total, terminal + tail

# Node 1 has a branch whose elements (4 and 5) are both keys.
mydata = {
    1: [[0.9, [2, 3]], [0.1, [4, 5]]],
    4: [[0.2, [6, 7]], [0.5, [8, 9]], [0.3, [10, 11, 12]]],
    5: [[0.4, [13, 14]], [0.6, [15, 16]]],
}

# Adds branches in which only one of the two elements is a key.
mydata2 = {
    1: [[0.8, [2, 3]], [0.1, [4, 5]], [0.05, [2, 4]], [0.05, [5, 6]]],
    4: [[0.2, [6, 7]], [0.5, [8, 9]], [0.3, [10, 11, 12]]],
    5: [[0.4, [13, 14]], [0.6, [15, 16]]],
}

# Adds one more level below node 5 through the keys 13 and 16.
mydata3 = {
    1: [[0.8, [2, 3]], [0.1, [4, 5]], [0.05, [2, 4]], [0.05, [5, 6]]],
    4: [[0.2, [6, 7]], [0.5, [8, 9]], [0.3, [10, 11, 12]]],
    5: [[0.4, [13, 14]], [0.6, [15, 16]]],
    13: [[0.58, [23, 32]], [0.42, [42]]],
    16: [[0.9, [17, 18]], [0.1, [20, 21]]],
}

输出

>>> for x in prob_tree(mydata,1):
    print(x)


(0.9, (2, 3))
(0.008000000000000002, (6, 7, 13, 14))
(0.012000000000000002, (6, 7, 15, 16))
(0.020000000000000004, (8, 9, 13, 14))
(0.03, (8, 9, 15, 16))
(0.012, (10, 11, 12, 13, 14))
(0.018, (10, 11, 12, 15, 16))
>>> 
>>> 
>>> for x in prob_tree(mydata2,1):
    print(x)


(0.8, (2, 3))
(0.008000000000000002, (6, 7, 13, 14))
(0.012000000000000002, (6, 7, 15, 16))
(0.020000000000000004, (8, 9, 13, 14))
(0.03, (8, 9, 15, 16))
(0.012, (10, 11, 12, 13, 14))
(0.018, (10, 11, 12, 15, 16))
(0.010000000000000002, (2, 6, 7))
(0.025, (2, 8, 9))
(0.015, (2, 10, 11, 12))
(0.020000000000000004, (6, 13, 14))
(0.03, (6, 15, 16))
>>> 
>>> 
>>> 
>>> for x in prob_tree(mydata3,1):
    print(x)


(0.8, (2, 3))
(0.004640000000000001, (6, 7, 14, 23, 32))
(0.003360000000000001, (6, 7, 14, 42))
(0.010800000000000002, (6, 7, 15, 17, 18))
(0.0012000000000000001, (6, 7, 15, 20, 21))
(0.0116, (8, 9, 14, 23, 32))
(0.008400000000000001, (8, 9, 14, 42))
(0.027000000000000003, (8, 9, 15, 17, 18))
(0.003, (8, 9, 15, 20, 21))
(0.006959999999999999, (10, 11, 12, 14, 23, 32))
(0.00504, (10, 11, 12, 14, 42))
(0.0162, (10, 11, 12, 15, 17, 18))
(0.0018, (10, 11, 12, 15, 20, 21))
(0.010000000000000002, (2, 6, 7))
(0.025, (2, 8, 9))
(0.015, (2, 10, 11, 12))
(0.0116, (6, 14, 23, 32))
(0.008400000000000001, (6, 14, 42))
(0.027000000000000003, (6, 15, 17, 18))
(0.003, (6, 15, 20, 21))
>>> 

编辑2 添加循环引用检查

def prob_tree_with_check(data,ini,visited=frozenset()):
    """Yield ``(probability, path)`` for every end point reachable from ``ini``.

    ``data`` maps a node to a list of ``[probability, path]`` branches.
    Every element of a path that is itself a key of ``data`` is expanded
    recursively; the other elements are kept as they are.  Each yielded
    path starts with the kept elements, followed by the end points of the
    expanded elements, and its probability is the product of all branch
    probabilities involved.

    Args:
        data: mapping ``node -> [[probability, [n0, n1, ...]], ...]``.
        ini: node to start from; must be a key of ``data``.
        visited: nodes already on the current chain of ancestors.  Callers
            normally leave the default.

    Raises:
        RuntimeError: if a node is reached again through one of its own
            descendants (circular reference).  Because this is a generator,
            the error surfaces while iterating, possibly after some end
            points have already been yielded.
        KeyError: if ``ini`` is not a key of ``data``.
    """
    if ini in visited:
        # Wrap ini in a 1-tuple: with a bare tuple key, "%r" % ini would
        # raise TypeError instead of the intended RuntimeError.
        raise RuntimeError("Branch already visited: %r" % (ini,))
    visited = visited.union((ini,))
    for prob, path in data[ini]:
        # Split the path in one pass into elements that end here and
        # elements that must be expanded further.
        no_more, more = [], []
        for item in path:
            (more if item in data else no_more).append(item)
        no_more = tuple(no_more)
        if not more:
            yield prob, no_more
            continue
        sub_trees = [prob_tree_with_check(data, key, visited) for key in more]
        # product() enumerates every combination of sub-tree end points.
        for combination in itertools.product(*sub_trees):
            new_prob = prob
            new_path = tuple()
            for sub_prob, sub_path in combination:
                new_prob = new_prob * sub_prob
                new_path = new_path + sub_path
            yield new_prob, no_more + new_path

# Tree with a circular reference: the last branch of node 5 lists node 1.
mydata_bad = {
    1: [[0.9, [2, 3]], [0.1, [4, 5]]],
    4: [[0.2, [6, 7]], [0.5, [8, 9]], [0.3, [10, 11, 12]]],
    5: [[0.4, [13, 14]], [0.6, [15, 16, 1]]],  # <-- try to go back to 1
}

输出

>>> for x in prob_tree_with_check(mydata_bad,1):
    x


(0.9, (2, 3))
Traceback (most recent call last):
  File "<pyshell#35>", line 1, in <module>
    for x in prob_tree_with_check(mydata_bad,1):
  File "C:\Users\David\Documents\Python Scripts\stackoverflow_test.py", line 137, in prob_tree_with_check
    for node in itertools.product( *[prob_tree_with_check(data,x,visited) for x in more] ):
  File "C:\Users\David\Documents\Python Scripts\stackoverflow_test.py", line 137, in prob_tree_with_check
    for node in itertools.product( *[prob_tree_with_check(data,x,visited) for x in more] ):
  File "C:\Users\David\Documents\Python Scripts\stackoverflow_test.py", line 132, in prob_tree_with_check
    raise RuntimeError("Branch already visited: %r"%ini)
RuntimeError: Branch already visited: 1
>>>