如何减少程序的运行时间?

时间:2015-02-28 09:38:34

标签: python recursion runtime

每当我尝试运行程序时,程序都需要几分钟。 cycle_length指的是Collat​​z周期长度(Collatz conjecture声称,无论您从哪个数字开始,如果执行给定的计算,您最终都会达到1max_length计算ij之间的整数的循环长度,以确定哪个数字产生最长的循环。

def cycle_length(n):
    if n == 1:
        return n    
    elif n % 2 == 0:
        return cycle_length(n//2) + 1
    else:
        return cycle_length(3*n + 1) + 1

def max_length(i,j):
    mxl = cycle_length(i)
    mxn = i
    while i <= j:
        start = time.time()
        y = cycle_length(i)
        if y > mxl:
            mxl = y
            mxn = i
        i += 1
    return (mxn,mxl)

print(max_length(1, 10**6)) 

我想将程序从1迭代到10**6。是否有任何有效的方法可以使程序更快(10秒以下)?

3 个答案:

答案 0 :(得分:0)

Collat​​z序列是使用"dynamic programming" techniques的绝佳机会。来自给定数字n的序列的长度将始终相同,因此您可以存储该结果,并在将来某个序列中再次到达n时使用该结果。例如,使用装饰器"memoize"结果:

def memo(func):
    def wrapper(*args):
        if args not in wrapper.cache:
            wrapper.cache[args] = func(*args)
        return wrapper.cache[args]
    wrapper.cache = {}
    return wrapper

@memo
def collatz_length(n):
    if n == 1:
        return 1
    elif n % 2:
        return 1 + collatz_length((3 * n) + 1)
    return 1 + collatz_length(n // 2)

现在,如果我们为n的几个起始值运行程序,您可以看到cache充满了预先计算的结果,从而加快了未来的调用(以存储空间为代价) - 经典的编程权衡):

>>> for x in range(1, 11):
    print x, collatz_length(x)


1 1
2 2
3 8
4 3
5 6
6 9
7 17
8 4
9 20
10 7
>>> collatz_length.cache
{(34,): 14, (9,): 20, (11,): 15, (13,): 10, (26,): 11, (1,): 1, (28,): 19 
 (3,): 8, (5,): 6, (16,): 5, (7,): 17, (20,): 8, (22,): 16, (8,): 4, (10,): 7,
 (14,): 18, (52,): 12, (2,): 2, (40,): 9, (4,): 3, (6,): 9, (17,): 13}

这可以给你大约1.5s的结果:

>>> from timeit import timeit
>>> timeit('max(collatz_length(x+1) for x in range(10**6))',
           setup='from __main__ import collatz_length',
           number=1)
1.5072860717773438

答案 1 :(得分:0)

你做了很多时间相同的计算。为避免这种情况,我们可以存储值:

cache = {1 : 1}

def cycle_length(n):
    if n in cache.keys():
        return cache[n]
    elif n % 2 == 0:
        x = cycle_length(n//2) + 1
    else:
        x = cycle_length(3*n + 1) + 1
    cache[n] = x
    return x

def max_length(i,j):
    mxl = cycle_length(i)
    mxn = i
    while i <= j:
        y = cycle_length(i)
        if y > mxl:
            mxl = y
            mxn = i
        i += 1
    return (mxn,mxl)

print(max_length(1, 10**6))

在我的机器上,您的代码以29.9秒的速度运行。我的代码给出了相同的结果,并在1.8秒内运行。


编辑:我写了一个脚本来比较3个答案的效率。

from functools import lru_cache
import time

cache = {1 : 1}
def cycle_length1(n):
    if n in cache.keys():
        return cache[n]
    elif n % 2 == 0:
        x = cycle_length1(n//2) + 1
    else:
        x = cycle_length1(3*n + 1) + 1
    cache[n] = x
    return x

def memo(func):
    def wrapper(*args):
        if args not in wrapper.cache:
            wrapper.cache[args] = func(*args)
        return wrapper.cache[args]
    wrapper.cache = {}
    return wrapper

@memo
def cycle_length2(n):
    if n == 1:
        return 1
    elif n % 2:
        return 1 + cycle_length2((3 * n) + 1)
    return 1 + cycle_length2(n // 2)

@lru_cache(maxsize=None)
def cycle_length3(n):
    if n == 1:
        return n
    elif n % 2 == 0:
        return cycle_length3(n//2) + 1
    else:
        return cycle_length3(3*n + 1) + 1

def max_length(f, i=1, j=10**6):
    mxl = f(i)
    mxn = i
    while i <= j:
        y = f(i)
        if y > mxl:
            mxl = y
            mxn = i
        i += 1
    return (mxn,mxl)

for f in [cycle_length1, cycle_length2, cycle_length3]:
    tic = time.time()
    print(max_length(f))
    print("%s\n" % (time.time() - tic))

结果:

(837799, 525)
1.5899293422698975

(837799, 525)
4.623902320861816

(837799, 525)
4.403488874435425

似乎这个答案效率最高(我认为开销较小,我们操纵一个简单的字典)。

答案 2 :(得分:0)

最简单的方法是记住cycle_length的值。如果您确实使用的是Python 3(3.3 +),则可以使用cycle_length装饰您的函数lru_cache(maxsize=None);这个程序在我的笔记本电脑上运行时间为5秒:

from functools import lru_cache

@lru_cache(maxsize=None)
def cycle_length(n):
    if n == 1:
        return n    
    elif n % 2 == 0:
        return cycle_length(n//2) + 1
    else:
        return cycle_length(3*n + 1) + 1

def max_length(i,j):
    mxl = cycle_length(i)
    mxn = i
    while i <= j:
        start = time.time()
        y = cycle_length(i)
        if y > mxl:
            mxl = y
            mxn = i
        i += 1
    return (mxn,mxl)

print(max_length(1, 10**6))