如何使用动态编程解决此问题?

时间:2020-08-09 07:04:17

标签: algorithm math optimization dynamic-programming memoization

给出一个数字列表,例如[4 5 2 3],我需要最大化根据以下规则集获得的总和:

  1. 我需要从列表中选择一个号码,该号码将被删除。 例如。选择2将其列表显示为[4 5 3]。
  2. 如果要删除的号码有两个邻居,那么我应该得到此选择的结果,作为当前选定号码与其邻居之一的乘积,并且此乘积与另一个邻居相加。例如:如果选择2,则选择的结果为2 * 5 + 3。
  3. 如果我选择一个只有一个邻居的数字,那么结果就是所选数字与其邻居的乘积。
  4. 当他们只剩下一个数字时,它才被添加到结果中。

遵循这些规则,我需要选择数字以使结果最大化。

对于上面的列表,如果选择的顺序为4-> 2-> 3-> 5,则得出的总和为53,这是最大值。

我包含一个程序,该程序使您可以将元素集作为输入传递,并给出所有可能的总和,并且还指示最大和。

这里是a link

import itertools

l = [int(i) for i in input().split()]
p = itertools.permutations(l) 

c, cs = 1, -1
mm = -1
for i in p:
    var, s = l[:], 0
    print(c, ':', i)
    c += 1
    
    for j in i:
        print(' removing: ', j)
        pos = var.index(j)
        if pos == 0 or pos == len(var) - 1:
            if pos == 0 and len(var) != 1:
                s += var[pos] * var[pos + 1]
                var.remove(j)
            elif pos == 0 and len(var) == 1:
                s += var[pos]
                var.remove(j)
            if pos == len(var) - 1 and pos != 0:
                s += var[pos] * var[pos - 1]
                var.remove(j)
        else:
            mx = max(var[pos - 1], var[pos + 1])
            mn = min(var[pos - 1], var[pos + 1])
            
            s += var[pos] * mx + mn
            var.remove(j)
        
        if s > mm:
            mm = s
            cs = c - 1
        print(' modified list: ', var, '\n  sum:', s)

print('MAX SUM was', mm, ' at', cs)

1 个答案:

答案 0 :(得分:0)

考虑该问题的4个变体:那些消耗每个元素的对象,以及不消耗左侧,右侧或左右元素的对象。

在每种情况下,您都可以考虑删除最后一个元素,这会将问题分解为1个或2个子问题。

这解决了O(n ^ 3)时间的问题。这是一个解决问题的python程序。 solve_的4个变体分别与一个端点,一个端点或另一个端点不固定。毫无疑问,该程序可以减少(重复很多)。

def solve_00(seq, n, m, cache):
    key = ('00', n, m)
    if key in cache:
        return cache[key]
    assert m >= n
    if n == m:
        return seq[n]
    best = -1e9
    for i in range(n, m+1):
        left = solve_01(seq, n, i, cache) if i > n else 0
        right = solve_10(seq, i, m, cache) if i < m else 0
        best = max(best, left + right + seq[i])
    cache[key] = best
    return best


def solve_01(seq, n, m, cache):
    key = ('01', n, m)
    if key in cache:
        return cache[key]
    assert m >= n + 1
    if m == n + 1:
        return seq[n] * seq[m]
    best = -1e9
    for i in range(n, m):
        left = solve_01(seq, n, i, cache) if i > n else 0
        right = solve_11(seq, i, m, cache) if i < m - 1 else 0
        best = max(best, left + right + seq[i] * seq[m])
    cache[key] = best
    return best

def solve_10(seq, n, m, cache):
    key = ('10', n, m)
    if key in cache:
        return cache[key]
    assert m >= n + 1
    if m == n + 1:
        return seq[n] * seq[m]
    best = -1e9
    for i in range(n+1, m+1):
        left = solve_11(seq, n, i, cache) if i > n + 1 else 0
        right = solve_10(seq, i, m, cache) if i < m else 0
        best = max(best, left + right + seq[n] * seq[i])
    cache[key] = best
    return best

def solve_11(seq, n, m, cache):
    key = ('11', n, m)
    if key in cache:
        return cache[key]   
    assert m >= n + 2
    if m == n + 2:
        return max(seq[n] * seq[n+1] + seq[n+2], seq[n] + seq[n+1] * seq[n+2])
    best = -1e9
    for i in range(n + 1, m):
        left = solve_11(seq, n, i, cache) if i > n + 1 else 0
        right = solve_11(seq, i, m, cache) if i < m - 1 else 0
        best = max(best, left + right + seq[i] * seq[n] + seq[m], left + right + seq[i] * seq[m] + seq[n])
    cache[key] = best
    return best

for c in [[1, 1, 1], [4, 2, 3, 5], [1, 2], [1, 2, 3], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]:
    print(c, solve_00(c, 0, len(c)-1, dict()))