我对动态编程还很陌生,但是我正在努力变得更好。我有一本书的练习,它问我以下问题(略有删节):
您要根据集合{1,2,3,4,5,6}中的数字构造长度为N的序列。但是,您不能连续将数字i(i = 1、2、3、4、5、6)连续放置超过A [i]次,其中A是给定的数组。给定序列长度N(1 <= N <= 10 ^ 5)和约束数组A(1 <= A [i] <= 50),有多少个序列是可能的?
例如,如果A = {1,2,1,2,1,2}并且N = 2,这意味着您只能有一个连续的1,两个连续的2,一个连续的3,依此类推。像“ 11”一样无效,因为它有两个连续的1,而像“ 12”或“ 22”都有效。事实证明,这种情况的实际答案是33(共有36个两位数的序列,但是“ 11”,“ 33”和“ 55”都是无效的,给出了33)。
有人告诉我,解决此问题的一种方法是使用具有三种状态的动态编程。更具体地说,他们说要保留3d数组dp(i,j,k),其中i表示序列中我们当前所在的位置,j表示放置在位置i-1的元素,k表示该元素已在块中重复。他们还告诉我,对于过渡,我们可以将与j不同的每个元素放在i的位置,并且只有在A [j]> k时才能将j放入。
从理论上讲,这对我来说都是有意义的,但是我一直在努力实现这一点。除了初始化矩阵dp之外,我不知道如何开始实际的实现。通常,其他大多数练习都具有在矩阵中手动设置的某种“基本情况”,然后使用循环填充其他条目。
我想我特别困惑,因为这是一个3D阵列。
答案 0 :(得分:5)
让我们暂时不关心数组。让我们递归地实现它。令dp(i, j, k)
为长度i
,最后一个元素j
和k
在数组末尾连续出现j
的序列数。>
现在的问题是,我们如何递归地编写dp(i, j, k)
的解决方案。
我们知道我们要在j
时间中添加kth
,所以我们必须取每个长度为i - 1
的序列,并且j
出现{{1 }}次,然后向该序列添加另一个k - 1
。请注意,这只是j
。
但是如果dp(i - 1, j, k - 1)
怎么办?如果是这种情况,我们可以将一个k == 1
出现在每个长度为j
且不以i - 1
结尾的序列中。本质上,我们需要所有j
的总和,例如dp(i, x, k)
和A[x] >= k
。
这给出了我们的重复关系:
x != j
我们知道我们的答案将是长度为def dp(i, j, k):
# this is the base case, the number of sequences of length 1
# one if k is valid, otherwise zero
if i == 1: return int(k == 1)
if k > 1:
# get all the valid sequences [0...i-1] and add j to them
return dp(i - 1, j, k - 1)
if k == 1:
# get all valid sequences that don't end with j
res = 0
for last in range(len(A)):
if last == j: continue
for n_consec in range(1, A[last] + 1):
res += dp(i - 1, last, n_consec)
return res
的所有有效子序列,因此我们的最终答案是N
信不信由你,这是动态编程的基础。我们只是将主要问题分解为一系列子问题。当然,由于递归,现在我们的时间是指数级的。我们有两种方法可以降低这种情况:
缓存时,我们可以简单地跟踪每个(i,j,k)的结果,然后吐出再次调用时最初计算的结果。
使用数组。我们可以使用自下而上的dp重新实现这个想法,并拥有一个数组sum(dp(N, j, k) for j in range(len(A)) for k in range(1, A[j] + 1))
。我们所有的函数调用都只是成为for循环中的数组访问。请注意,使用此方法会迫使我们按照拓扑顺序对数组进行迭代,这可能很棘手。
答案 1 :(得分:1)
dp方法有两种:自顶向下和自底向上
从下至上,将终端箱填充到dp表中,然后使用for循环从中建立起来。让我们考虑自下而上的算法来生成斐波那契序列。我们设置dp[0] = 1
和dp[1] = 1
并从i = 2 to n
运行for循环。
在自上而下的方法中,我们从问题的“顶部”视图开始,然后从那里开始。考虑使用递归函数获得第n个斐波那契数:
def fib(n):
if n <= 1:
return 1
if dp[n] != -1:
return dp[n]
dp[n] = fib(n - 1) + fib(n - 2)
return dp[n]
在这里,我们没有填写完整的表格,而只是填写我们遇到的情况。
为什么我要谈论这两种类型是因为当您开始学习dp时,通常很难提出自下而上的方法(就像您要尝试的那样)。发生这种情况时,首先您想提出一种自上而下的方法,然后尝试从中获得自下而上的解决方案。
因此,让我们首先创建一个递归dp函数:
# let m be size of A
# initialize dp table with all values -1
def solve(i, j, k, n, m):
# first write terminal cases
if k > A[j]:
# this means sequence is invalid. so return 0
return 0
if i >= n:
# this means a valid sequence.
return 1
if dp[i][j][k] != -1:
return dp[i][j][k]
result = 0
for num = 1 to m:
if num == j:
result += solve(i + 1, num, k + 1, n)
else:
result += solve(i + 1, num, 1, n)
dp[i][j][k] = result
return dp[i][j][k]
所以我们知道什么是极端情况。我们创建一个大小为dp [n + 1] [m] [50]的dp表。用所有值0而不是-1初始化它。
所以我们可以自下而上地做:
# initially all values in table are zero. With loop below, we set the valid endings as 1.
# So any state trying to reach valid terminal states will get 1, but invalid states will
# return the values 0
for num = 1 to m:
for occour = 1 to A[num]:
dp[n][num][occour] = 1
# now to build up from bottom, we start by filling n-1 th position
for i = n-1 to 1:
for num = 1 to m:
for occour = 1 to A[num]:
for next_num = 1 to m:
if next_num != num:
dp[i][num][occour] += dp[i + 1][next_num][1]
else:
dp[i][num][occour] += dp[i + 1][num][occour + 1]
答案将是:
sum = 0
for num = 1 to m:
sum += dp[1][num][1]
我确信必须有一些更优雅的dp解决方案,但是我相信这可以回答您的问题。请注意,我认为k是第j个数字被连续重复的次数,如果我错了,请纠正我。
编辑:
在给定的约束下,在最坏的情况下,表的大小将为10 ^ 5 * 6 * 50 = 3e7。这将是> 100MB。它是可行的,但可以认为它占用了过多的空间(我认为某些内核不允许进程使用太多堆栈空间)。减少它的一种方法是使用 hash-map 而不是使用自顶向下方法的数组,因为自顶向下不会访问所有状态。在这种情况下,多数情况下都是如此,例如,如果A [1]为2,则所有其他其他状态(其中1出现的次数超过两次,则不需要存储)。当然,如果A [i]具有较大的值,例如[50,50,50,50,50,50],这将不会节省太多空间。另一种方法是稍微修改一下我们的方法。我们实际上不需要存储维度k,即j连续出现的时间:
dp[i][j] = no of ways from i-th position if (i - 1)th position didn't have j and i-th position is j.
然后,我们需要将算法修改为:
def solve(i, j):
if i == n:
return 1
if i > n:
return 0
if dp[i][j] != -1
return dp[i][j]
result = 0
# we will first try 1 consecutive j, then 2 consecutive j's then 3 and so on
for count = 1 to A[j]:
for num = 1 to m:
if num != j:
result += solve(i + count, num)
dp[i][j] = result
return dp[i][j]
这种方法会将我们的空间复杂度降低到O(10 ^ 6)〜= 2mb,而时间复杂度仍然相同:O(N * 6 * 50)