关于DP的性能讨论

时间:2013-08-16 07:16:46

标签: python performance dynamic-programming

看下面的代码,我用两种方法来解决问题(简单的递归和DP)。为什么DP方式更慢?

你的建议是什么?

#!/usr/local/bin/python2.7
# encoding: utf-8

问题:有一个正整数的数组。给定正整数S,\ 找到数字总和为S的组合总数。

方法I:

def find_sum_recursive(number_list, sum_to_find):
    count = 0

    for i in range(len(number_list)):        
        sub_sum = sum_to_find - number_list[i]
        if sub_sum < 0:
            continue
        elif sub_sum == 0:
            count += 1
            continue
        else:            
            sub_list = number_list[i + 1:]            
            count += find_sum_recursive(sub_list, sub_sum)       
    return count

方法II:

def find_sum_DP(number_list, sum_to_find):
    count = 0

    if(0 == sum_to_find):
        count = 1
    elif([] != number_list and sum_to_find > 0):   
        count = find_sum_DP(number_list[:-1], sum_to_find) + find_sum_DP(number_list[:-1], sum_to_find - number_list[:].pop())     

    return count

运行它:

def main(argv=None):  # IGNORE:C0111
    number_list = [5, 5, 10, 3, 2, 9, 8]
    sum_to_find = 15
    input_setup = ';number_list = [5, 5, 10, 3, 2, 9, 8, 7, 6, 4, 3, 2, 9, 5, 4, 7, 2, 8, 3];sum_to_find = 15'

    print 'Calculating...'
    print 'recursive starting'
    count = find_sum_recursive(number_list, sum_to_find)
    print timeit.timeit('count = find_sum_recursive(number_list, sum_to_find)', setup='from __main__ import find_sum_recursive' + input_setup, number=10)
    cProfile.run('find_sum_recursive(' + str(number_list) + ',' + str(sum_to_find) + ')')
    print 'recursive ended:', count    
    print 'DP starting'
    count_DP = find_sum_DP(number_list, sum_to_find)
    print timeit.timeit('count_DP = find_sum_DP(number_list, sum_to_find)', setup='from __main__ import find_sum_DP' + input_setup, number=10)
    cProfile.run('find_sum_DP(' + str(number_list) + ',' + str(sum_to_find) + ')')
    print 'DP ended:', count_DP        
    print 'Finished.'    

if __name__ == '__main__':
    sys.exit(main())

我重新编写方法II,现在就是:

def find_sum_DP(number_list, sum_to_find):
    count = [[0 for i in xrange(0, sum_to_find + 1)] for j in xrange(0, len(number_list) + 1)]    

    for i in range(len(number_list) + 1):
        for j in range(sum_to_find + 1):            
            if (0 == i and 0 == j):                
                count[i][j] = 1            
            elif (i > 0 and j > 0):                
                if (j > number_list[i - 1]):                    
                    count[i][j] = count[i - 1][j] + count[i - 1][j - number_list[i - 1]]
                elif(j < number_list[i - 1]):
                    count[i][j] = count[i - 1][j]
                else:
                    count[i][j] = count[i - 1][j] + 1                                
            else:                
                count[i][j] = 0

    return count[len(number_list)][sum_to_find]

比较方法I&amp; II:

Calculating...
recursive starting
0.00998711585999
         92 function calls (63 primitive calls) in 0.000 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.000    0.000 <string>:1(<module>)
     30/1    0.000    0.000    0.000    0.000 FindSum.py:18(find_sum_recursive)
       30    0.000    0.000    0.000    0.000 {len}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
       30    0.000    0.000    0.000    0.000 {range}


recursive ended: 6
DP starting
0.00171685218811
         15 function calls in 0.000 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000    0.000    0.000 <string>:1(<module>)
        1    0.000    0.000    0.000    0.000 FindSum.py:33(find_sum_DP)
        3    0.000    0.000    0.000    0.000 {len}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
        9    0.000    0.000    0.000    0.000 {range}


DP ended: 6
Finished.

2 个答案:

答案 0 :(得分:6)

如果你正在使用iPython,%prun就是你的朋友。

看一下递归版的输出:

         2444 function calls (1631 primitive calls) in 0.002 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    814/1    0.002    0.000    0.002    0.002 <ipython-input-1-7488a6455e38>:1(find_sum_recursive)
      814    0.000    0.000    0.000    0.000 {range}
      814    0.000    0.000    0.000    0.000 {len}
        1    0.000    0.000    0.002    0.002 <string>:1(<module>)
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}

现在,对于DP版本:

         10608 function calls (3538 primitive calls) in 0.007 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   7071/1    0.007    0.000    0.007    0.007 <ipython-input-15-3535e3ab26eb>:1(find_sum_DP)
     3535    0.001    0.000    0.001    0.000 {method 'pop' of 'list' objects}
        1    0.000    0.000    0.007    0.007 <string>:1(<module>)
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
7071比814高出一点!

你的问题在于你的动态编程方法不是动态编程!动态编程的重点是,当你遇到重叠子问题的问题时,就像你在这里一样,你存储每个子问题的结果,然后当您再次需要结果时,从该商店获取它而不是重新计算。您的代码不会这样做:每次调用find_sum_DP时,即使已经完成相同的计算,也会重新计算。结果是你的_DP方法实际上不仅是递归的,而是递归函数调用而不是递归方法。

(我目前正在编写DP版本来演示)

编辑:

我需要添加警告,虽然我应该更多地了解动态编程,但我非常尴尬。我也是在深夜和深夜写这个,有点像我自己的练习。不过,这里是函数的动态编程实现:

import numpy as np
def find_sum_realDP( number_list, sum_to_find ):
    memo = np.zeros( (len(number_list),sum_to_find+1) ,dtype=np.int)-1
    # This will store our results. memo[l][n] will give us the result
    # for number_list[0:l+1] and a sum_to_find of n. If it hasn't been
    # calculated yet, it will give us -1. This is not at all efficient
    # storage, but isn't terribly bad.

    # Now that we have that, we'll call the real function. Instead of modifying
    # the list and making copies or views, we'll keep the same list, and keep
    # track of the index we're on (nli).
    return find_sum_realDP_do( number_list, len(number_list)-1, sum_to_find, memo ),memo

def find_sum_realDP_do( number_list, nli, sum_to_find, memo ):
    # Our count is 0 by default.
    ret = 0

    # If we aren't at the sum to find yet, do we have any numbers left after this one?
    if ((sum_to_find > 0) and nli>0):
        # Each of these checks to see if we've already stored the result of the calculation.
        # If so, we use that, if not, we calculate it.
        if memo[nli-1,sum_to_find]>=0:
            ret += memo[nli-1,sum_to_find]
        else:
            ret += find_sum_realDP_do(number_list, nli-1, sum_to_find, memo)

        # This one is a bit tricky, and was a bug when I first wrote it. We don't want to
        # have a negative sum_to_find, because that will be very bad; we'll start using results
        # from other places in memo because it will wrap around.
        if (sum_to_find-number_list[nli]>=0) and memo[nli-1,sum_to_find-number_list[nli]]>=0:
            ret += memo[nli-1,sum_to_find-number_list[nli]]
        elif (sum_to_find-number_list[nli]>=0):
            ret += find_sum_realDP_do(number_list, nli-1, sum_to_find-number_list[nli], memo)

    # Do we not actually have any sum to find left?     
    elif (0 == sum_to_find):
        ret = 1

    # If we only have one number left, will it get us there?
    elif (nli == 0) and (sum_to_find-number_list[nli] == 0 ):
        ret = 1

    # Store our result.
    memo[nli,sum_to_find] = ret

    # Return our result.
    return ret        

请注意,这会使用numpy。你很可能没有安装这个,但我不确定如何在没有它的情况下用Python编写一个合理执行的动态编程算法;我不认为Python列表可以接近Numpy阵列的性能。另请注意,这与您的代码处理零的方式不同,因此我只是说这个代码是针对数字列表中的非零正整数而不是调试它。现在,通过这种算法,分析为我们提供了:

         243 function calls (7 primitive calls) in 0.001 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    237/1    0.001    0.000    0.001    0.001 <ipython-input-155-4a624e5a99b7>:9(find_sum_realDP_do)
        1    0.000    0.000    0.001    0.001 <ipython-input-155-4a624e5a99b7>:1(find_sum_realDP)
        1    0.000    0.000    0.000    0.000 {numpy.core.multiarray.zeros}
        1    0.000    0.000    0.001    0.001 <string>:1(<module>)
        2    0.000    0.000    0.000    0.000 {len}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}

243比递归版更好!但是你的示例数据足够小,以至于真的显示出动态编程算法的优势。

让我们尝试nlist2 = [7, 6, 2, 3, 7, 7, 2, 7, 4, 2, 4, 5, 6, 1, 7, 4, 6, 3, 2, 1, 1, 1, 4, 2, 3, 5, 2, 4, 4, 2, 4, 5, 4, 2, 1, 7, 6, 6, 1, 5, 4, 5, 3, 2, 3, 7, 1, 7, 6, 6],使用相同的sum_to_find=15。这有50个值,900206种方法可以获得15 ...

使用find_sum_recursive

         3335462 function calls (2223643 primitive calls) in 14.137 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
1111820/1   13.608    0.000   14.137   14.137 <ipython-input-46-7488a6455e38>:1(find_sum_recursive)
  1111820    0.422    0.000    0.422    0.000 {range}
  1111820    0.108    0.000    0.108    0.000 {len}
        1    0.000    0.000   14.137   14.137 <string>:1(<module>)
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}

现在使用find_sum_realDP

         736 function calls (7 primitive calls) in 0.007 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    730/1    0.007    0.000    0.007    0.007 <ipython-input-155-4a624e5a99b7>:9(find_sum_realDP_do)
        1    0.000    0.000    0.007    0.007 <ipython-input-155-4a624e5a99b7>:1(find_sum_realDP)
        1    0.000    0.000    0.000    0.000 {numpy.core.multiarray.zeros}
        1    0.000    0.000    0.007    0.007 <string>:1(<module>)
        2    0.000    0.000    0.000    0.000 {len}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}

因此,我们的呼叫次数不到1/1000,并且运行时间不到1/2000。当然,您使用的列表越大,DP算法的效果就越好。在我的计算机上,使用sum_to_find为15运行以及从1到8的600个随机数列表,realDP只需要0.09秒,并且函数调用少于10,000;在这一点上,我正在使用的64位整数开始溢出,我们有各种其他问题。毋庸置疑,递归算法永远无法在计算机停止运行之前处理任何接近该大小的列表,无论是内部材料发生故障还是宇宙热死亡。

答案 1 :(得分:1)

有一件事是您的代码列出了许多列表复制。如果它只是传递索引或索引来定义“窗口视图”而不是全部复制列表,那将会更快。对于第一种方法,您可以轻松添加参数starting_index并在for循环中使用它。在第二种方法中,您编写number_list[:].pop()并复制整个列表,以获取您可以简单地执行number_list[-1]的最后一个元素。你也可以添加一个参数ending_index并在你的测试中使用它(len(number_list) == ending_index而不是number_list != [],顺便说一句,即使只是普通number_list比测试空列表更好。“ / p>