高效的通用Python memoize

时间:2012-12-28 18:47:57

标签: python memoization

我有一个通用的Python memoizer:

cache = {}

def memoize(f): 
    """Memoize any function."""

    def decorated(*args):
        key = (f, str(args))
        result = cache.get(key, None)
        if result is None:
            result = f(*args)
            cache[key] = result
        return result

    return decorated

它有效,但我对它不满意,因为有时效率不高。最近,我使用了一个将列表作为参数的函数,并且显然使用整个列表创建键会减慢所有内容。最好的方法是什么? (即,有效地计算密钥,无论args如何,无论它们多长还是复杂)

我想这个问题实际上是关于如何从args和泛型memoizer的函数有效地生成密钥 - 我在一个程序中观察到,糟糕的密钥(生成成本太高)对运行时产生了重大影响。我的编程用'str(args)'拍摄了45秒,但我可以用手工制作的键将其减少到3秒。不幸的是,手工制作的密钥是特定于这个编程,但我想要一个快速的记事本,我不必每次都为缓存推出特定的,手工制作的密钥。

2 个答案:

答案 0 :(得分:6)

首先,如果您非常确定此处O(N)散列是合理且必要的,并且您只想使用比hash(str(x))更快的算法加快速度,请尝试以下操作:

def hash_seq(iterable):
    result = hash(type(iterable))
    for element in iterable:
        result ^= hash(element)
    return result

当然这对于可能很深的序列不起作用,但是有一个显而易见的方法:

def hash_seq(iterable):
    result = hash(type(iterable))
    for element in iterable:
        try:
            result ^= hash(element)
        except TypeError:
            result ^= hash_seq(element)
    return result

我不认为这是一个足够好的哈希算法,因为它会为同一个列表的不同排列返回相同的值。但我很确定没有足够好的哈希算法会快得多。至少如果它是用C或Cython编写的,如果这是你要去的方向,你最终可能会想做。

另外,值得注意的是,在str(或marshal)不会的情况下,这是正确的 - 例如,如果您的list可能有一些可变元素{ {1}}涉及其repr而不是其值。但是,它在所有情况下仍然不正确。特别是,它假定“迭代相同的元素”对于任何可迭代类型意味着“相等”,这显然不能保证是真的。假阴性不是很大,但误报是(例如,两个id s具有相同的键,但不同的值可能虚假地比较相等并共享备忘录。)

此外,它不使用额外的空间,而是使用相当大的乘数的O(N)。

无论如何,首先尝试这一点是值得的,然后才决定是否值得分析是否足够好和微调优化。

这是浅层实现的一个简单的Cython版本:

dict

从快速测试开始,纯Python实现非常慢(正如您所期望的那样,所有Python循环,与def test_cy_xor(iterable): cdef int result = hash(type(iterable)) cdef int h for element in iterable: h = hash(element) result ^= h return result str中的C循环相比),但是Cython版本轻松获胜:

marshal

只是在Cython中迭代序列并且什么也不做(实际上只是N调用 test_str( 3): 0.015475 test_marshal( 3): 0.008852 test_xor( 3): 0.016770 test_cy_xor( 3): 0.004613 test_str(10000): 8.633486 test_marshal(10000): 2.735319 test_xor(10000): 24.895457 test_cy_xor(10000): 0.716340 和一些引用计数,所以你不会在原生C中做得更好)是70%的同时为PyIter_Next。你可以通过要求一个实际的序列而不是一个可迭代来使它更快,更需要一个test_cy_xor,尽管它可能需要编写显式C而不是Cython才能获得好处。

无论如何,我们如何解决订购问题?显而易见的Python解决方案是散列list而不是(i, element),但所有这些元组操作都会使Cython版本减慢到12倍。标准解决方案是在每个xor之间乘以一些数字。但是,当你在它的时候,值得尝试让值很好地分散为短序列,小element元素和其他非常常见的边缘情况。选择正确的数字很棘手,所以......我只是从tuple借来的。这是完整的测试。

_hashtest.pyx:

int

hashtest.py:

cdef _test_xor(seq):
    cdef long result = 0x345678
    cdef long mult = 1000003
    cdef long h
    cdef long l = 0
    try:
        l = len(seq)
    except TypeError:
        # NOTE: This probably means very short non-len-able sequences
        # will not be spread as well as they should, but I'm not
        # sure what else to do.
        l = 100
    for element in seq:
        try:
            h = hash(element)
        except TypeError:
            h = _test_xor(element)
        result ^= h
        result *= mult
        mult += 82520 + l + l
    result += 97531
    return result

def test_xor(seq):
    return _test_xor(seq) ^ hash(type(seq))

输出:

import marshal
import random
import timeit
import pyximport
pyximport.install()
import _hashtest

def test_str(seq):
    return hash(str(seq))

def test_marshal(seq):
    return hash(marshal.dumps(seq))

def test_cy_xor(seq):
    return _hashtest.test_xor(seq)

# This one is so slow that I don't bother to test it...
def test_xor(seq):
    result = hash(type(seq))
    for i, element in enumerate(seq):
        try:
            result ^= hash((i, element))
        except TypeError:
            result ^= hash(i, hash_seq(element))
    return result

smalltest = [1,2,3]
bigtest = [random.randint(10000, 20000) for _ in range(10000)]

def run():
    for seq in smalltest, bigtest:
        for f in test_str, test_marshal, test_cy_xor:
            print('%16s(%5d): %9f' % (f.func_name, len(seq),
                                      timeit.timeit(lambda: f(seq), number=10000)))

if __name__ == '__main__':
    run()

以下是一些提高此速度的潜在方法:

  • 如果你有很多深层序列,而不是在 test_str( 3): 0.014489 test_marshal( 3): 0.008746 test_cy_xor( 3): 0.004686 test_str(10000): 8.563252 test_marshal(10000): 2.744564 test_cy_xor(10000): 0.904398 周围使用try,请致电hash并检查-1。
  • 如果你知道你有一个序列(或者更好,特别是PyObject_Hash),而不仅仅是一个可迭代的,list(或PySequence_ITEM)可能会更快而不是上面隐式使用的PyList_GET_ITEM

在任何一种情况下,一旦你开始调用C API调用,通常更容易删除Cython并在C中编写函数。(你仍然可以使用Cython在该C函数周围编写一个简单的包装器,而不是手动编码扩展模块。)此时,只需直接借用PyIter_Next代码,而不是重新实现相同的算法。

如果您正在寻找一种避免tuplehash的方法,那是不可能的。如果你看看tuple.__hash__frozenset.__hash__ImmutableSet.__hash__是如何工作的(最后一个是纯Python而且非常易读),那么它们都会O(N)。但是,他们都会缓存哈希值。因此,如果您经常对相同的 O(N)(而不是非相同但相等的)进行哈希处理,则它会接近恒定时间。 (tuple,其中O(N/M)是您使用M调用的次数。)

如果你可以假设你的tuple个对象在调用之间永远不变,你可以显然做同样的事情,例如,使用list映射dictid外部缓存。但总的来说,这显然不是一个合理的假设。 (如果您的hash对象永远不会变异,那么只需切换到list个对象就更容易了,而不必担心所有这些复杂性。)

但是你可以将一个tuple对象包装在一个子类中,该子类添加一个缓存的哈希值成员(或槽),并在缓存调用时调用缓存(list,{{1 },append等)。然后你的__setitem__可以检查一下。

最终结果与__delitem__ s:分摊hash_seq具有相同的正确性和效果,但tuple O(N/M)是您与每个tuple调用的次数相同的M,而对于tuple,它是您使用每个相同的list进行调用而不会在其间发生变异的次数。

答案 1 :(得分:3)

您可以尝试以下几点:

使用marshal.dumps而不是str可能会稍快一点(至少在我的机器上):

>>> timeit.timeit("marshal.dumps([1,2,3])","import marshal", number=10000)
0.008287056301007567
>>> timeit.timeit("str([1,2,3])",number=10000)
0.01709315717356219

另外,如果你的函数计算成本很高,并且可能自己都返回None,那么你的memoizing函数每次都会重新计算它们(我可能会到达这里,但不知道更多我只能猜测)。 结合这两件事给出了:

import marshal
cache = {}

def memoize(f): 
    """Memoize any function."""

    def decorated(*args):
        key = (f, marshal.dumps(args))
        if key in cache:
            return cache[key]

        cache[key] = f(*args)
        return cache[key]

    return decorated