生成最佳二叉搜索树(Cormen)

时间:2017-09-11 17:09:34

标签: python algorithm binary-search-tree dynamic-programming

我正在阅读Cormen等人,算法导论(第3版)(PDF),关于最佳二叉搜索树的第15.4节,但我在实施Python中optimal_bst函数的伪代码。

以下是我尝试将最佳BST应用于:

的示例

enter image description here

让我们将e[i,j]定义为搜索包含从ij标记的键的最佳二叉搜索树的预期成本。最后,我们希望计算e[1, n],其中n是键的数量(本例中为5)。最终的递归公式是:

enter image description here

应该通过以下伪代码实现:

enter image description here

请注意,伪代码可互换地使用基于1和0的索引,而Python仅使用后者。结果我在实现伪代码时遇到了麻烦。以下是我到目前为止的情况:

import numpy as np

p = [0.15, 0.10, 0.05, 0.10, 0.20]
q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(p)

e = np.diag(q)
w = np.diag(q)
root = np.zeros((n, n))
for l in range(1, n+1):
    for i in range(n-l+1):
        j = i + l
        e[i, j] = np.inf
        w[i, j] = w[i, j-1] + p[j-1] + q[j]
        for r in range(i, j+1):
            t = e[i-1, r-1] + e[r, j] + w[i-1, j]
            if t < e[i-1, j]:
                e[i-1, j] = t
                root[i-1, j] = r

print(w)
print(e)

但是,如果我运行此功能,则会正确计算权重w,但预期的搜索值e仍会在其初始化值处“卡住”:

[[ 0.05  0.3   0.45  0.55  0.7   1.  ]
 [ 0.    0.1   0.25  0.35  0.5   0.8 ]
 [ 0.    0.    0.05  0.15  0.3   0.6 ]
 [ 0.    0.    0.    0.05  0.2   0.5 ]
 [ 0.    0.    0.    0.    0.05  0.35]
 [ 0.    0.    0.    0.    0.    0.1 ]]
[[ 0.05   inf   inf   inf   inf   inf]
 [ 0.    0.1    inf   inf   inf   inf]
 [ 0.    0.    0.05   inf   inf   inf]
 [ 0.    0.    0.    0.05   inf   inf]
 [ 0.    0.    0.    0.    0.05   inf]
 [ 0.    0.    0.    0.    0.    0.1 ]]

我期望ewroot如下:

enter image description here

到目前为止,我已经调试了几个小时但仍然卡住了。有人可以指出上面的Python代码有什么问题吗?

3 个答案:

答案 0 :(得分:0)

在我看来,你在指数中犯了一个错误。我无法按预期工作,但下面的代码应该给你一个指示我前进的地方(可能在某个地方有一个关闭):

import numpy as np

p = [0.15, 0.10, 0.05, 0.10, 0.20]
q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(p)

def get2(m, i, j):
    return m[i - 1, j - 1]


def set2(m, i, j, v):
    m[i - 1, j - 1] = v


def get1(m, i):
    return m[i - 1]


def set1(m, i, v):
    m[i - 1] = v


e = np.diag(q)
w = np.diag(q)
root = np.zeros((n, n))
for l in range(1, n + 1):
    for i in range(n - l + 2):
        j = i + l - 1
        set2(e, i, j, np.inf)
        set2(w, i, j, get2(w, i, j - 1) + get1(p, j) + get1(q, j))
        for r in range(i, j + 1):
            t = get2(e, i, r - 1) + get2(e, r + 1, j) + get2(w, i, j)
            if t < get2(e, i, j):
                set2(e, i, j, t)
                set2(root, i, j, r)

print(w)
print(e)

结果:

[[ 0.2   0.4   0.5   0.65  0.9   0.  ]
 [ 0.    0.2   0.3   0.45  0.7   0.  ]
 [ 0.    0.    0.1   0.25  0.5   0.  ]
 [ 0.    0.    0.    0.15  0.4   0.  ]
 [ 0.    0.    0.    0.    0.25  0.  ]
 [ 0.5   0.7   0.8   0.95  0.    0.3 ]]
[[ 0.2   0.6   0.8   1.2   1.95  0.  ]
 [ 0.    0.2   0.4   0.8   1.35  0.  ]
 [ 0.    0.    0.1   0.35  0.85  0.  ]
 [ 0.    0.    0.    0.15  0.55  0.  ]
 [ 0.    0.    0.    0.    0.25  0.  ]
 [ 0.7   1.2   1.5   2.    0.    0.3 ]]

答案 1 :(得分:0)

最后,我使用Series'DataFrameindex对象初始化自定义columnsimport numpy as np import pandas as pd P = [0.15, 0.10, 0.05, 0.10, 0.20] Q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10] n = len(P) p = pd.Series(P, index=range(1, n+1)) q = pd.Series(Q) e = pd.DataFrame(np.diag(Q), index=range(1, n+2)) w = pd.DataFrame(np.diag(Q), index=range(1, n+2)) root = pd.DataFrame(np.zeros((n, n)), index=range(1, n+1), columns=range(1, n+1)) for l in range(1, n+1): for i in range(1, n-l+2): j = i+l-1 e.set_value(i, j, np.inf) w.set_value(i, j, w.get_value(i, j-1) + p[j] + q[j]) for r in range(i, j+1): t = e.get_value(i, r-1) + e.get_value(r+1, j) + w.get_value(i, j) if t < e.get_value(i, j): e.set_value(i, j, t) root.set_value(i, j, r) print(e) print(w) print(root) 来强制数组具有相同的索引在伪代码中。之后,伪代码几乎可以复制粘贴:

      0     1     2     3     4     5
1  0.05  0.45  0.90  1.25  1.75  2.75
2  0.00  0.10  0.40  0.70  1.20  2.00
3  0.00  0.00  0.05  0.25  0.60  1.30
4  0.00  0.00  0.00  0.05  0.30  0.90
5  0.00  0.00  0.00  0.00  0.05  0.50
6  0.00  0.00  0.00  0.00  0.00  0.10
      0    1     2     3     4     5
1  0.05  0.3  0.45  0.55  0.70  1.00
2  0.00  0.1  0.25  0.35  0.50  0.80
3  0.00  0.0  0.05  0.15  0.30  0.60
4  0.00  0.0  0.00  0.05  0.20  0.50
5  0.00  0.0  0.00  0.00  0.05  0.35
6  0.00  0.0  0.00  0.00  0.00  0.10
     1    2    3    4    5
1  1.0  1.0  2.0  2.0  2.0
2  0.0  2.0  2.0  2.0  4.0
3  0.0  0.0  3.0  4.0  5.0
4  0.0  0.0  0.0  4.0  5.0
5  0.0  0.0  0.0  0.0  5.0

产生预期结果:

If

我仍然会对Numpy阵列的解决方案感兴趣,因为这对我来说似乎更优雅。

答案 2 :(得分:0)

库尔特, 谢谢你的帖子!您的问题是我发现的唯一可行的解​​决方案。我花了很多钱来处理索引。这是我使用numpy数组的实现。

import numpy as np
import math

def optimalBST(p,q,n):

    e = np.zeros((n+1)**2).reshape(n+1,n+1)
    w = np.zeros((n+1)**2).reshape(n+1,n+1)
    root = np.zeros((n+1)**2).reshape(n+1,n+1)

    # Initialization
    for i in range(n+1):
        e[i,i] = q[i]
        w[i,i] = q[i]
    for i in range(0,n):
        root[i,i] = i+1

    for l in range(1,n+1):
        for i in range(0, n-l+1):
            j = i+l
            min_ = math.inf
            w[i,j] = w[i,j-1] + p[j] + q[j]
            for r in range(i,j):
                t = e[i, r-1+1] + e[r+1,j] +  w[i,j]
                if t < min_:
                    min_ = t                
                    e[i, j] = t
                    root[i, j-1] = r+1

    root_pruned = np.delete(np.delete(root, n, 1), n, 0)        # Trim last col & row.

    print("------ e -------")
    print(e)
    print("------ w -------")
    print(w)
    print("----- root -----")
    print(root_pruned)

def main():

    p = [0,.15,.1,.05,.1,.2]
    q = [.05,.1,.05,.05,.05,.1]
    n = len(p)-1

    optimalBST(p,q,n)

if __name__ == '__main__':
    main()

输出:

------ e -------
[[0.05 0.45 0.9  1.25 1.75 2.75]
 [0.   0.1  0.4  0.7  1.2  2.  ]
 [0.   0.   0.05 0.25 0.6  1.3 ]
 [0.   0.   0.   0.05 0.3  0.9 ]
 [0.   0.   0.   0.   0.05 0.5 ]
 [0.   0.   0.   0.   0.   0.1 ]]
------ w -------
[[0.05 0.3  0.45 0.55 0.7  1.  ]
 [0.   0.1  0.25 0.35 0.5  0.8 ]
 [0.   0.   0.05 0.15 0.3  0.6 ]
 [0.   0.   0.   0.05 0.2  0.5 ]
 [0.   0.   0.   0.   0.05 0.35]
 [0.   0.   0.   0.   0.   0.1 ]]
----- root -----
[[1. 1. 2. 2. 2.]
 [0. 2. 2. 2. 4.]
 [0. 0. 3. 4. 5.]
 [0. 0. 0. 4. 5.]
 [0. 0. 0. 0. 5.]]