Minimizing the number of costly function calls when the argument stays the same (python)

Date: 2013-12-06 23:11:21

Tags: python optimization

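Suppose there is a function costly_function_a(x) such that:

  1. it is very expensive to execute;
  2. it returns the same output whenever it is given the same x; and
  3. it performs no "additional tasks" besides returning that output.

Under these conditions, instead of calling the function twice in a row with the same x, we can store the result in a temporary variable and then use that variable in the computations.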
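The temporary-variable idea can be sketched as follows. The call counter attached to costly_function_a and the names twice_naive and twice_cached are hypothetical, added only to make the number of real calls visible.

```python
def costly_function_a(x):
    # Stand-in for an expensive, pure computation (hypothetical counter added).
    costly_function_a.calls += 1
    return x  # Dummy operation.

costly_function_a.calls = 0

def twice_naive(x):
    # Two separate calls with the same x: the expensive work is done twice.
    return costly_function_a(x) + 2. * costly_function_a(x) ** 2

def twice_cached(x):
    # One call; the result is reused through a local variable.
    a_at_x = costly_function_a(x)
    return a_at_x + 2. * a_at_x ** 2
```

Both functions return 10.0 for x = 2., but twice_naive calls costly_function_a twice while twice_cached calls it once.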

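Now suppose that some functions (f(x), g(x) and h(x) in the example below) call costly_function_a(x), and that some of these functions may call each other (in the example below, g(x) and h(x) both call f(x)). In this case, the simple approach above still results in repeated calls to costly_function_a(x) with the same x (see OkayVersion below). I did find a way to minimize the number of calls, but it is "ugly" (see FastVersion below). Is there a better way to do this?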

    #Dummy functions representing extremely slow code.
    #The goal is to call these costly functions as rarely as possible.
    def costly_function_a(x):
        print("costly_function_a has been called.")
        return x #Dummy operation.
    def costly_function_b(x):
        print("costly_function_b has been called.")
        return 5.*x #Dummy operation.
    
    #Simplest (but slowest) implementation.
    class SlowVersion:
        def __init__(self,a,b):
            self.a = a
            self.b = b
        def f(self,x): #Dummy operation.
            return self.a(x) + 2.*self.a(x)**2
        def g(self,x): #Dummy operation.
            return self.f(x) + 0.7*self.a(x) + .1*x
        def h(self,x): #Dummy operation.
            return self.f(x) + 0.5*self.a(x) + self.b(x) + 3.*self.b(x)**2
    
    #Equivalent to SlowVersion, but call the costly functions less often.
    class OkayVersion:
        def __init__(self,a,b):
            self.a = a
            self.b = b
        def f(self,x): #Same result as SlowVersion.f(x)
            a_at_x = self.a(x)
            return a_at_x + 2.*a_at_x**2
        def g(self,x): #Same result as SlowVersion.g(x)
            return self.f(x) + 0.7*self.a(x) + .1*x
        def h(self,x): #Same result as SlowVersion.h(x)
            a_at_x = self.a(x)
            b_at_x = self.b(x)
            return self.f(x) + 0.5*a_at_x + b_at_x + 3.*b_at_x**2
    
    #Equivalent to SlowVersion, but calls the costly functions even less often.
    #Is this the simplest way to do it? I am aware that this code is highly
    #redundant. One could simplify it by defining some factory functions...
    class FastVersion:
        def __init__(self,a,b):
            self.a = a
            self.b = b
        def f(self, x, _at_x=None): #Same result as SlowVersion.f(x)
            if _at_x is None:
                _at_x = dict()
            if 'a' not in _at_x:
                _at_x['a'] = self.a(x)
            return _at_x['a'] + 2.*_at_x['a']**2
        def g(self, x, _at_x=None): #Same result as SlowVersion.g(x)
            if _at_x is None:
                _at_x = dict()
            if 'a' not in _at_x:
                _at_x['a'] = self.a(x)
            return self.f(x,_at_x) + 0.7*_at_x['a'] + .1*x
        def h(self,x,_at_x=None): #Same result as SlowVersion.h(x)
            if _at_x is None:
                _at_x = dict()
            if 'a' not in _at_x:
                _at_x['a'] = self.a(x)
            if 'b' not in _at_x:
                _at_x['b'] = self.b(x)
            return self.f(x,_at_x) + 0.5*_at_x['a'] + _at_x['b'] + 3.*_at_x['b']**2
    
    if __name__ == '__main__':
    
        slow = SlowVersion(costly_function_a,costly_function_b)
        print("Using slow version.")
        print("f(2.) = " + str(slow.f(2.)))
        print("g(2.) = " + str(slow.g(2.)))
        print("h(2.) = " + str(slow.h(2.)) + "\n")
    
        okay = OkayVersion(costly_function_a,costly_function_b)
        print("Using okay version.")
        print("f(2.) = " + str(okay.f(2.)))
        print("g(2.) = " + str(okay.g(2.)))
        print("h(2.) = " + str(okay.h(2.)) + "\n")
    
        fast = FastVersion(costly_function_a,costly_function_b)
        print("Using fast version 'casually'.")
        print("f(2.) = " + str(fast.f(2.)))
        print("g(2.) = " + str(fast.g(2.)))
        print("h(2.) = " + str(fast.h(2.)) + "\n")
    
        print("Using fast version 'optimally'.")
        _at_x = dict()
        print("f(2.) = " + str(fast.f(2.,_at_x)))
        print("g(2.) = " + str(fast.g(2.,_at_x)))
        print("h(2.) = " + str(fast.h(2.,_at_x)))
        #Of course, one must "clean up" _at_x before using a different x...
    

The output of this code is:

    Using slow version.
    costly_function_a has been called.
    costly_function_a has been called.
    f(2.) = 10.0
    costly_function_a has been called.
    costly_function_a has been called.
    costly_function_a has been called.
    g(2.) = 11.6
    costly_function_a has been called.
    costly_function_a has been called.
    costly_function_a has been called.
    costly_function_b has been called.
    costly_function_b has been called.
    h(2.) = 321.0
    
    Using okay version.
    costly_function_a has been called.
    f(2.) = 10.0
    costly_function_a has been called.
    costly_function_a has been called.
    g(2.) = 11.6
    costly_function_a has been called.
    costly_function_b has been called.
    costly_function_a has been called.
    h(2.) = 321.0
    
    Using fast version 'casually'.
    costly_function_a has been called.
    f(2.) = 10.0
    costly_function_a has been called.
    g(2.) = 11.6
    costly_function_a has been called.
    costly_function_b has been called.
    h(2.) = 321.0
    
    Using fast version 'optimally'.
    costly_function_a has been called.
    f(2.) = 10.0
    g(2.) = 11.6
    costly_function_b has been called.
    h(2.) = 321.0
    

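Note that I do not want to "store" the results for every value of x used in the past (that would require too much memory). Also, I do not want the functions to return a tuple of the form (f,g,h), because in some cases I only need f (and then there is no need to evaluate costly_function_b).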

2 Answers:

Answer 0 (score: 11)

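What you are looking for is an LRU cache: only the most recently used items are kept, which limits memory use and balances the cost of the calls against the memory requirements.

Because the costly functions are called with different values of x, several return values are cached (one per unique x value), and when the cache is full the least recently used cached result is discarded.

Since Python 3.2, the standard library ships a decorator implementation of this: @functools.lru_cache()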
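To make the eviction policy concrete, here is a minimal sketch of an LRU cache built on collections.OrderedDict. The LRUCache class and its get(key, compute) method are hypothetical names for illustration; this is not how functools.lru_cache is implemented internally.

```python
from collections import OrderedDict

class LRUCache:
    """Keeps at most `maxsize` entries; when a new key is added to a full
    cache, the least recently used entry is evicted."""

    def __init__(self, maxsize):
        self.maxsize = maxsize
        self._data = OrderedDict()  # ordered from least to most recently used

    def get(self, key, compute):
        if key in self._data:
            # Hit: mark the key as most recently used and return the stored value.
            self._data.move_to_end(key)
            return self._data[key]
        # Miss: compute the value and store it as the most recently used entry.
        value = compute(key)
        self._data[key] = value
        if len(self._data) > self.maxsize:
            # Evict the least recently used entry (the front of the dict).
            self._data.popitem(last=False)
        return value

    def keys(self):
        return list(self._data)

cache = LRUCache(2)
cache.get(1, lambda k: k * 10)  # miss
cache.get(2, lambda k: k * 10)  # miss
cache.get(1, lambda k: k * 10)  # hit: 1 becomes the most recently used key
cache.get(3, lambda k: k * 10)  # miss: the cache is full, so key 2 is evicted
print(cache.keys())  # [1, 3]
```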

from functools import lru_cache

@lru_cache(16)  # cache 16 different `x` return values
def costly_function_a(x):
    print("costly_function_a has been called.")
    return x #Dummy operation.

@lru_cache(32)  # cache 32 different `x` return values
def costly_function_b(x):
    print("costly_function_b has been called.")
    return 5.*x #Dummy operation.
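A decorated function also gains cache_info() and cache_clear() methods, which make it easy to check whether the cache is actually being hit. A short sketch using the dummy function from above:

```python
from functools import lru_cache

@lru_cache(maxsize=16)  # keep results for up to 16 distinct x values
def costly_function_a(x):
    print("costly_function_a has been called.")
    return x  # Dummy operation.

costly_function_a(2.)  # miss: the body runs and prints
costly_function_a(2.)  # hit: served from the cache, nothing is printed
costly_function_a(3.)  # miss: a new x value
print(costly_function_a.cache_info())
# CacheInfo(hits=1, misses=2, maxsize=16, currsize=2)
costly_function_a.cache_clear()  # drop every cached result
```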

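A backport is available for earlier versions, or you can pick one of the other libraries on PyPI that provide an LRU cache.

If you only ever need to cache the single most recent item, create your own decorator: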

from functools import wraps

def cache_most_recent(func):
    cache = [None, None]
    @wraps(func)
    def wrapper(*args, **kw):
        if (args, kw) == cache[0]:
            return cache[1]
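        result = func(*args, **kw)
        # Update the key only after func succeeds, so an exception cannot
        # leave a new key paired with a stale value.
        cache[0] = args, kw
        cache[1] = result
        return result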
    return wrapper

@cache_most_recent
def costly_function_a(x):
    print("costly_function_a has been called.")
    return x #Dummy operation.

@cache_most_recent
def costly_function_b(x):
    print("costly_function_b has been called.")
    return 5.*x #Dummy operation.

This simpler decorator has less overhead than the more full-featured functools.lru_cache().
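One limitation applies to any cache of one (cache_most_recent above, or functools.lru_cache(maxsize=1)): it only saves work when the same argument is repeated back to back. The sketch below uses lru_cache(maxsize=1) so that it runs on its own, and records every real execution in a list named calls (a hypothetical helper for illustration).

```python
from functools import lru_cache

calls = []  # records every real execution of the costly body

@lru_cache(maxsize=1)  # remembers only the most recent argument
def costly_function_a(x):
    calls.append(x)
    return x  # Dummy operation.

# The same x repeated back to back: only the first call does any work.
for x in (2., 2., 2.):
    costly_function_a(x)
print(calls)  # [2.0]

# Alternating between two values: every call is a miss.
calls.clear()
costly_function_a.cache_clear()
for x in (2., 3., 2., 3.):
    costly_function_a(x)
print(calls)  # [2.0, 3.0, 2.0, 3.0]
```

In the question's f, g and h, every costly call made during one top-level call uses the same x, which is exactly the pattern a cache of one handles well.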

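Answer 1 (score: 1)

I am accepting @MartijnPieters' solution, since it is probably the right way to go for 99% of people with a problem similar to mine. However, in my particular case I only need a "cache of 1", so the fancy @lru_cache(1) decorator is a bit overkill. I ended up writing my own decorator (thanks to this excellent Stack Overflow answer), which I give below. Note that I am new to Python, so this code may not be perfect.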

from functools import wraps

def last_cache(func):
    """A decorator caching the last value returned by a function.

    If the decorated function is called twice (or more) in a row with exactly
    the same parameters, then this decorator will return a cached value of the
    decorated function's last output instead of calling it again. This may
    speed up execution if the decorated function is costly to call.

    The decorated function must respect the following conditions:
    1.  Repeated calls return the same value if the same parameters are used.
    2.  The function's only "task" is to return a value.

    """
    _first_call = [True]
    _last_args = [None]
    _last_kwargs = [None]
    _last_value = [None]
    @wraps(func)
    def _last_cache_wrapper(*args, **kwargs):
        if _first_call[0] or (args != _last_args[0]) or (kwargs != _last_kwargs[0]):
            # Call func first so that an exception leaves the cache unchanged.
            value = func(*args, **kwargs)
            _first_call[0] = False
            _last_args[0] = args
            _last_kwargs[0] = kwargs
            _last_value[0] = value
        return _last_value[0]
    return _last_cache_wrapper
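A cache of one can also be plugged into the question's classes without touching f, g or h: wrap the injected functions once in __init__. The sketch below uses functools.lru_cache(maxsize=1) so that it runs on its own; last_cache above should drop in the same way (self.a = last_cache(a)). The CachedVersion name is hypothetical. Because the wrapping happens per instance, each object gets its own cache, and a new x simply replaces the cached entry, so there is no _at_x dictionary to clean up.

```python
from functools import lru_cache

def costly_function_a(x):
    print("costly_function_a has been called.")
    return x  # Dummy operation.

def costly_function_b(x):
    print("costly_function_b has been called.")
    return 5. * x  # Dummy operation.

class CachedVersion:
    """Same f, g, h as SlowVersion; only __init__ differs."""

    def __init__(self, a, b):
        # Each instance wraps its own copies, so caches are not shared.
        self.a = lru_cache(maxsize=1)(a)
        self.b = lru_cache(maxsize=1)(b)

    def f(self, x):
        return self.a(x) + 2. * self.a(x) ** 2

    def g(self, x):
        return self.f(x) + 0.7 * self.a(x) + .1 * x

    def h(self, x):
        return self.f(x) + 0.5 * self.a(x) + self.b(x) + 3. * self.b(x) ** 2

if __name__ == '__main__':
    cached = CachedVersion(costly_function_a, costly_function_b)
    print("f(2.) = " + str(cached.f(2.)))
    print("g(2.) = " + str(cached.g(2.)))
    print("h(2.) = " + str(cached.h(2.)))
```

With the dummy functions, each costly function prints its message only once across the three calls at x = 2., matching the FastVersion run "optimally" in the question, but with the plain SlowVersion bodies.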