在玩游戏时,我经常编写简单的递归函数,如下所示:
def f(a,b):
if a>=0 and b>=0:
return min( f(a-1,b) , f(b,a-1) ) # + some cost that depends on a,b
else:
return 0
(例如,在计算加权编辑距离或评估递归定义的数学公式时。)
然后我使用memoizing装饰器自动缓存结果。
当我尝试类似f(200,10)的东西时,我得到:
RuntimeError: maximum recursion depth exceeded
这是预期的,因为递归实现耗尽了Python的堆栈空间/递归限制。
我通常通过以下方式解决此问题:
但我发现所有这些都很容易出错。
有没有办法编写一个@Bigstack装饰器来模拟真正大堆栈的效果?
请注意,我的函数通常会进行多次递归函数调用,因此这与尾递归不同 - 我确实希望保存堆栈中每个函数的所有内部状态。
我一直在考虑使用生成器表达式列表作为我的堆栈。通过探测堆栈帧,我可以在递归调用函数时解决问题,然后触发异常以返回装饰器代码。但是,我无法找到一种方法将这些想法粘合在一起,以制作任何实际有用的东西。
或者,我可以尝试访问该函数的抽象语法树,并尝试将调用转换为递归函数以生成语句,但这看起来似乎朝着错误的方向前进。
有什么建议吗?
修改
当然看起来我在滥用Python,但我一直在考虑的另一种方法是为每个块使用不同的线程,比如500个堆栈帧,然后在每个连续的线程对之间插入队列 - 一个用于参数的队列,以及返回值的另一个队列。 (每个队列中最多只有一个条目。)我认为这可能由于某种原因不起作用 - 但我可能只会在我尝试实现它之后找出原因。
答案 0 :(得分:5)
要绕过递归限制,您可以捕获RuntimeError
异常以检测何时用完堆栈空间,然后返回一个continuation-ish函数,该函数在调用时重新启动递归。你没空间的地方。调用此(及其返回值,依此类推),直到获得值,然后从顶部再次尝试。一旦你记住了较低级别,较高级别将不会遇到递归限制,因此最终这将起作用。将重复调用直到它的工作放在包装函数中。基本上它是你的热身缓存理念的懒惰版本。
这是一个简单的递归“添加从1到n的数字”功能的例子。
import functools
def memoize(func):
cache = {}
@functools.wraps(func)
def wrapper(*args, **kwargs):
key = args, tuple(sorted(kwargs.items()))
if key in cache:
return cache[key]
else:
result = func(*args, **kwargs)
if not callable(result):
cache[key] = result
return result
return wrapper
@memoize
def _addup(n):
if n < 2:
return n
else:
try:
result = _addup(n - 1)
except RuntimeError:
return lambda: _addup(n)
else:
return result if callable(result) else result + n
def addup(n):
result = _addup(n)
while callable(result):
while callable(result):
result = result()
result = _addup(n)
return result
assert addup(5000) == sum(xrange(5001))
我没有将lambda函数一直返回到调用链中,而是提出了一个异常短路,这样可以提高性能并简化代码:
# memoize function as above, or you can probably use functools.lru_cache
class UnwindStack(Exception):
pass
@memoize
def _addup(n):
if n < 2:
return n
else:
try:
return _addup(n - 1) + n
except RuntimeError:
raise UnwindStack(lambda: _addup(n))
def _try(func, *args, **kwargs):
try:
return func(*args, **kwargs)
except UnwindStack as e:
return e[0]
def addup(n):
result = _try(_addup, n)
while callable(result):
while callable(result):
result = _try(result)
result = _try(_addup, n)
return result
但这仍然非常不优雅,但仍然有相当大的开销,我无法想象你是如何制作装饰品的。我想,Python并不适合这种事情。
答案 1 :(得分:2)
这是一个使用生成器表达式列表作为堆栈的实现:
def run_stackless(frame):
stack, return_stack = [(False, frame)], []
while stack:
active, frame = stack.pop()
action, res = frame.send(return_stack.pop() if active else None)
if action == 'call':
stack.extend([(True, frame), (False, res)])
elif action == 'tail':
stack.append((False, res))
elif action == 'return':
return_stack.append(res)
else:
raise ValueError('Unknown action', action)
return return_stack.pop()
要使用它,您需要根据以下规则转换递归函数:
return expr -> yield 'return', expr
recursive_call(args...) -> (yield 'call', recursive_call(args...))
return recursive_call(args...) -> yield 'tail', recursive_call(args...)
例如,使用成本函数a * b
,您的函数将变为:
def f(a,b):
if a>=0 and b>=0:
yield 'return', min((yield 'call', f(a-1,b)),
(yield 'call', f(b,a-1))) + (a * b)
else:
yield 'return', 0
测试:
In [140]: run_stackless(g(30, 4))
Out[140]: 410
在Python 2.6.2中,与直接调用相比,它的性能提升约为8-10倍。
tail
动作用于尾递归:
def factorial(n):
acc = [1]
def fact(n):
if n == 0:
yield 'return', 0
else:
acc[0] *= n
yield 'tail', fact(n - 1)
run_stackless(fact(n))
return acc[0]
转换为生成器递归样式相当容易,可能会以字节码黑客的方式完成。
答案 2 :(得分:1)
这种方法将记忆和增加的堆栈深度结合到一个装饰器中。
我生成一个线程池,每个线程负责堆栈的64个级别 线程只创建一次并重新启动(但目前从未删除)。
队列用于在线程之间传递信息,但请注意,只有与当前堆栈深度相对应的线程实际上才能完成工作。
我的实验表明,这为简单的递归函数增加了大约10%的开销(对于更复杂的函数应该更少)。
import threading,Queue
class BigstackThread(threading.Thread):
def __init__(self,send,recv,func):
threading.Thread.__init__( self )
self.daemon = True
self.send = send
self.recv = recv
self.func = func
def run(self):
while 1:
args = self.send.get()
v = self.func(*args)
self.recv.put(v)
class Bigstack(object):
def __init__(self,func):
self.func = func
self.cache = {}
self.depth = 0
self.threadpool = {}
def __call__(self,*args):
if args in self.cache:
return self.cache[args]
self.depth+=1
if self.depth&63:
v = self.func(*args)
else:
T=self.threadpool
if self.depth not in T:
send = Queue.Queue(1)
recv = Queue.Queue(1)
t = BigstackThread(send,recv,self)
T[self.depth] = send,recv,t
t.start()
else:
send,recv,_ = T[self.depth]
send.put(args)
v = recv.get()
self.depth-=1
self.cache[args]=v
return v
@Bigstack
def f(a,b):
if a>=0 and b>=0:
return min(f(a-1,b),f(b-1,a))+1
return 0