我正在阅读Cormen等人,算法导论(第3版)(PDF),关于最佳二叉搜索树的第15.4节,但我在实施Python中optimal_bst
函数的伪代码。
以下是我尝试将最佳BST应用于:
的示例让我们将e[i,j]
定义为搜索包含从i
到j
标记的键的最佳二叉搜索树的预期成本。最后,我们希望计算e[1, n]
,其中n
是键的数量(本例中为5)。最终的递归公式是:
应该通过以下伪代码实现:
请注意,伪代码可互换地使用基于1和0的索引,而Python仅使用后者。结果我在实现伪代码时遇到了麻烦。以下是我到目前为止的情况:
import numpy as np
p = [0.15, 0.10, 0.05, 0.10, 0.20]
q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(p)
e = np.diag(q)
w = np.diag(q)
root = np.zeros((n, n))
for l in range(1, n+1):
for i in range(n-l+1):
j = i + l
e[i, j] = np.inf
w[i, j] = w[i, j-1] + p[j-1] + q[j]
for r in range(i, j+1):
t = e[i-1, r-1] + e[r, j] + w[i-1, j]
if t < e[i-1, j]:
e[i-1, j] = t
root[i-1, j] = r
print(w)
print(e)
但是,如果我运行此功能,则会正确计算权重w
,但预期的搜索值e
仍会在其初始化值处“卡住”:
[[ 0.05 0.3 0.45 0.55 0.7 1. ]
[ 0. 0.1 0.25 0.35 0.5 0.8 ]
[ 0. 0. 0.05 0.15 0.3 0.6 ]
[ 0. 0. 0. 0.05 0.2 0.5 ]
[ 0. 0. 0. 0. 0.05 0.35]
[ 0. 0. 0. 0. 0. 0.1 ]]
[[ 0.05 inf inf inf inf inf]
[ 0. 0.1 inf inf inf inf]
[ 0. 0. 0.05 inf inf inf]
[ 0. 0. 0. 0.05 inf inf]
[ 0. 0. 0. 0. 0.05 inf]
[ 0. 0. 0. 0. 0. 0.1 ]]
我期望e
,w
和root
如下:
到目前为止,我已经调试了几个小时但仍然卡住了。有人可以指出上面的Python代码有什么问题吗?
答案 0 :(得分:0)
在我看来,你在指数中犯了一个错误。我无法按预期工作,但下面的代码应该给你一个指示我前进的地方(可能在某个地方有一个关闭):
import numpy as np
p = [0.15, 0.10, 0.05, 0.10, 0.20]
q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(p)
def get2(m, i, j):
return m[i - 1, j - 1]
def set2(m, i, j, v):
m[i - 1, j - 1] = v
def get1(m, i):
return m[i - 1]
def set1(m, i, v):
m[i - 1] = v
e = np.diag(q)
w = np.diag(q)
root = np.zeros((n, n))
for l in range(1, n + 1):
for i in range(n - l + 2):
j = i + l - 1
set2(e, i, j, np.inf)
set2(w, i, j, get2(w, i, j - 1) + get1(p, j) + get1(q, j))
for r in range(i, j + 1):
t = get2(e, i, r - 1) + get2(e, r + 1, j) + get2(w, i, j)
if t < get2(e, i, j):
set2(e, i, j, t)
set2(root, i, j, r)
print(w)
print(e)
结果:
[[ 0.2 0.4 0.5 0.65 0.9 0. ]
[ 0. 0.2 0.3 0.45 0.7 0. ]
[ 0. 0. 0.1 0.25 0.5 0. ]
[ 0. 0. 0. 0.15 0.4 0. ]
[ 0. 0. 0. 0. 0.25 0. ]
[ 0.5 0.7 0.8 0.95 0. 0.3 ]]
[[ 0.2 0.6 0.8 1.2 1.95 0. ]
[ 0. 0.2 0.4 0.8 1.35 0. ]
[ 0. 0. 0.1 0.35 0.85 0. ]
[ 0. 0. 0. 0.15 0.55 0. ]
[ 0. 0. 0. 0. 0.25 0. ]
[ 0.7 1.2 1.5 2. 0. 0.3 ]]
答案 1 :(得分:0)
最后,我使用Series
'DataFrame
和index
对象初始化自定义columns
和import numpy as np
import pandas as pd
P = [0.15, 0.10, 0.05, 0.10, 0.20]
Q = [0.05, 0.10, 0.05, 0.05, 0.05, 0.10]
n = len(P)
p = pd.Series(P, index=range(1, n+1))
q = pd.Series(Q)
e = pd.DataFrame(np.diag(Q), index=range(1, n+2))
w = pd.DataFrame(np.diag(Q), index=range(1, n+2))
root = pd.DataFrame(np.zeros((n, n)), index=range(1, n+1), columns=range(1, n+1))
for l in range(1, n+1):
for i in range(1, n-l+2):
j = i+l-1
e.set_value(i, j, np.inf)
w.set_value(i, j, w.get_value(i, j-1) + p[j] + q[j])
for r in range(i, j+1):
t = e.get_value(i, r-1) + e.get_value(r+1, j) + w.get_value(i, j)
if t < e.get_value(i, j):
e.set_value(i, j, t)
root.set_value(i, j, r)
print(e)
print(w)
print(root)
来强制数组具有相同的索引在伪代码中。之后,伪代码几乎可以复制粘贴:
0 1 2 3 4 5
1 0.05 0.45 0.90 1.25 1.75 2.75
2 0.00 0.10 0.40 0.70 1.20 2.00
3 0.00 0.00 0.05 0.25 0.60 1.30
4 0.00 0.00 0.00 0.05 0.30 0.90
5 0.00 0.00 0.00 0.00 0.05 0.50
6 0.00 0.00 0.00 0.00 0.00 0.10
0 1 2 3 4 5
1 0.05 0.3 0.45 0.55 0.70 1.00
2 0.00 0.1 0.25 0.35 0.50 0.80
3 0.00 0.0 0.05 0.15 0.30 0.60
4 0.00 0.0 0.00 0.05 0.20 0.50
5 0.00 0.0 0.00 0.00 0.05 0.35
6 0.00 0.0 0.00 0.00 0.00 0.10
1 2 3 4 5
1 1.0 1.0 2.0 2.0 2.0
2 0.0 2.0 2.0 2.0 4.0
3 0.0 0.0 3.0 4.0 5.0
4 0.0 0.0 0.0 4.0 5.0
5 0.0 0.0 0.0 0.0 5.0
产生预期结果:
If
我仍然会对Numpy阵列的解决方案感兴趣,因为这对我来说似乎更优雅。
答案 2 :(得分:0)
库尔特, 谢谢你的帖子!您的问题是我发现的唯一可行的解决方案。我花了很多钱来处理索引。这是我使用numpy数组的实现。
import numpy as np
import math
def optimalBST(p,q,n):
e = np.zeros((n+1)**2).reshape(n+1,n+1)
w = np.zeros((n+1)**2).reshape(n+1,n+1)
root = np.zeros((n+1)**2).reshape(n+1,n+1)
# Initialization
for i in range(n+1):
e[i,i] = q[i]
w[i,i] = q[i]
for i in range(0,n):
root[i,i] = i+1
for l in range(1,n+1):
for i in range(0, n-l+1):
j = i+l
min_ = math.inf
w[i,j] = w[i,j-1] + p[j] + q[j]
for r in range(i,j):
t = e[i, r-1+1] + e[r+1,j] + w[i,j]
if t < min_:
min_ = t
e[i, j] = t
root[i, j-1] = r+1
root_pruned = np.delete(np.delete(root, n, 1), n, 0) # Trim last col & row.
print("------ e -------")
print(e)
print("------ w -------")
print(w)
print("----- root -----")
print(root_pruned)
def main():
p = [0,.15,.1,.05,.1,.2]
q = [.05,.1,.05,.05,.05,.1]
n = len(p)-1
optimalBST(p,q,n)
if __name__ == '__main__':
main()
输出:
------ e -------
[[0.05 0.45 0.9 1.25 1.75 2.75]
[0. 0.1 0.4 0.7 1.2 2. ]
[0. 0. 0.05 0.25 0.6 1.3 ]
[0. 0. 0. 0.05 0.3 0.9 ]
[0. 0. 0. 0. 0.05 0.5 ]
[0. 0. 0. 0. 0. 0.1 ]]
------ w -------
[[0.05 0.3 0.45 0.55 0.7 1. ]
[0. 0.1 0.25 0.35 0.5 0.8 ]
[0. 0. 0.05 0.15 0.3 0.6 ]
[0. 0. 0. 0.05 0.2 0.5 ]
[0. 0. 0. 0. 0.05 0.35]
[0. 0. 0. 0. 0. 0.1 ]]
----- root -----
[[1. 1. 2. 2. 2.]
[0. 2. 2. 2. 4.]
[0. 0. 3. 4. 5.]
[0. 0. 0. 4. 5.]
[0. 0. 0. 0. 5.]]