Python:回想一下依赖于新函数参数

时间:2017-05-19 18:57:28

标签: python-3.x caching memoization

我对缓存和概念的概念比较陌生。记忆化。我已经阅读了其他一些讨论&资源hereherehere,但我们无法很好地关注它们。

假设我在一个类中有两个成员函数。 (下面的简化示例。)假设第一个函数total在计算上很昂贵。第二个函数subtotal在计算上是简单的,除了它使用第一个函数的返回,因此在计算上也变得昂贵,因为它当前需要重新调用total来获取它返回结果。

我想缓存第一个函数的结果,并将其用作第二个函数的输入, if 输入ysubtotal共享输入{{1最近一次致电x。那就是:

  • 如果调用subtotal(),其中total等于中的y值 之前调用x,然后使用该缓存结果而不是
    重新呼叫total
  • 否则,只需使用total致电total()

示例:

x = y

3 个答案:

答案 0 :(得分:1)

使用Python3.2或更高版本,您可以使用functools.lru_cache。 如果您直接使用total装饰functools.lru_cache,则lru_cache会根据两个参数total的值缓存self的返回值和x。由于lru_cache的内部dict存储对self的引用,因此将@lru_cache直接应用于类方法会创建对self的循环引用,这会使类的实例不可解除引用(因此会导致内存泄漏)。

Here is a workaround允许您将lru_cache与类方法一起使用 - 它会根据除第一个self以外的所有参数来缓存结果,并使用weakref避免循环引用问题:

import functools
import weakref

def memoized_method(*lru_args, **lru_kwargs):
    """
    https://stackoverflow.com/a/33672499/190597 (orly)
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapped_func(self, *args, **kwargs):
            # We're storing the wrapped method inside the instance. If we had
            # a strong reference to self the instance would never die.
            self_weak = weakref.ref(self)
            @functools.wraps(func)
            @functools.lru_cache(*lru_args, **lru_kwargs)
            def cached_method(*args, **kwargs):
                return func(self_weak(), *args, **kwargs)
            setattr(self, func.__name__, cached_method)
            return cached_method(*args, **kwargs)
        return wrapped_func
    return decorator


class MyObject(object):

    def __init__(self, a, b):
        self.a, self.b = a, b

    @memoized_method()
    def total(self, x):
        print('Calling total (x={})'.format(x))
        return (self.a + self.b) * x


    def subtotal(self, y, z):
        return self.total(x=y) + z 

mobj = MyObject(1,2)
mobj.subtotal(10, 20)
mobj.subtotal(10, 30)

打印

Calling total (x=10)

只有一次。

或者,这是使用dict滚动自己的缓存的方法:

class MyObject(object):

    def __init__(self, a, b):
        self.a, self.b = a, b
        self._total = dict()

    def total(self, x):
        print('Calling total (x={})'.format(x))
        self._total[x] = t = (self.a + self.b) * x
        return t

    def subtotal(self, y, z):
        t = self._total[y] if y in self._total else self.total(y)
        return t + z 

mobj = MyObject(1,2)
mobj.subtotal(10, 20)
mobj.subtotal(10, 30)

这个基于dict的缓存的lru_cache的一个优点是lru_cache 是线程安全的。 lru_cache也有一个maxsize参数可以提供帮助 防止内存使用量不受限制地增长(例如,由于 长时间运行的流程多次使用total的不同值调用x

答案 1 :(得分:1)

感谢大家的回复,只是阅读它们,看看幕后发生了什么是有帮助的。正如@Tadhg McDonald-Jensen所说,我似乎并不需要比@functools.lru_cache更多的东西。 (我在Python 3.5中。)关于@ unutbu的评论,我没有因使用@lru_cache装饰total()而收到错误。让我纠正我自己的例子,我会在这里为其他初学者保留这个:

from functools import lru_cache
from datetime import datetime as dt

class MyObject(object):
    def __init__(self, a, b):
        self.a, self.b = a, b

    @lru_cache(maxsize=None)
    def total(self, x):        
        lst = []
        for i in range(int(1e7)):
            val = self.a + self.b + x    # time-expensive loop
            lst.append(val)
        return np.array(lst)     

    def subtotal(self, y, z):
        return self.total(x=y) + z       # if y==x from a previous call of
                                         # total(), used cached result.

myobj = MyObject(1, 2)

# Call total() with x=20
a = dt.now()
myobj.total(x=20)
b = dt.now()
c = (b - a).total_seconds()

# Call subtotal() with y=21
a2 = dt.now()
myobj.subtotal(y=21, z=1)
b2 = dt.now()
c2 = (b2 - a2).total_seconds()

# Call subtotal() with y=20 - should take substantially less time
# with x=20 used in previous call of total().
a3 = dt.now()
myobj.subtotal(y=20, z=1)
b3 = dt.now()
c3 = (b3 - a3).total_seconds()

print('c: {}, c2: {}, c3: {}'.format(c, c2, c3))
c: 2.469753, c2: 2.355764, c3: 0.016998

答案 2 :(得分:0)

在这种情况下,我会做一些简单的事情,也许不是最优雅的方式,但可以解决问题:

class MyObject(object):
    param_values = {}
    def __init__(self, a, b):
        self.a, self.b = a, b

    def total(self, x):
        if x not in MyObject.param_values:
          MyObject.param_values[x] = (self.a + self.b) * x
          print(str(x) + " was never called before")
        return MyObject.param_values[x]

    def subtotal(self, y, z):
        if y in MyObject.param_values:
          return MyObject.param_values[y] + z
        else:
          return self.total(y) + z