如何编写通用memoize函数?

时间:2008-09-24 20:48:41

标签: optimization recursion lua closures memoization

我正在编写一个函数来查找triangle numbers,并且以递归方式编写它的自然方式:

function triangle (x)
   if x == 0 then return 0 end
   return x+triangle(x-1)
end

但是尝试计算前100,000个三角形数字会在一段时间后出现堆栈溢出而失败。这是memoize的理想函数,但是我想要一个能够记住我传递给它的任何函数的解决方案。

16 个答案:

答案 0 :(得分:7)

Mathematica有一种特别灵活的记忆方式,依赖哈希和函数调用使用相同语法的事实:

triangle[0] = 0;
triangle[x_] := triangle[x] = x + triangle[x-1]

就是这样。它的工作原理是因为模式匹配函数调用的规则总是在更一般的定义之前使用更具体的定义。

当然,正如已经指出的那样,这个例子有一个封闭形式的解决方案:triangle[x_] := x*(x+1)/2。 Fibonacci数字是添加memoization如何大幅加速的典型例子:

fib[0] = 1;
fib[1] = 1;
fib[n_] := fib[n] = fib[n-1] + fib[n-2]

虽然它也有一个封闭形式的等价物,尽管比较混乱:http://mathworld.wolfram.com/FibonacciNumber.html

我不同意建议这不适合进行回忆的人,因为你可以“只使用循环”。记忆点是任何重复函数调用都是O(1)时间。这比O(n)要好很多。实际上,您甚至可以编写一个场景,其中memoized实现的性能优于封闭式实现!

答案 1 :(得分:6)

你也问你原来的问题是错误的问题;)

对于这种情况,这是一种更好的方式:

triangle(n)= n *(n - 1)/ 2

此外,假设公式没有这样一个简洁的解决方案,那么备忘录在这里仍然是一个糟糕的方法。在这种情况下,你最好只编写一个简单的循环。有关更全面的讨论,请参阅this answer

答案 2 :(得分:5)

我打赌这样的事情应该适用于Lua中的变量参数列表:

local function varg_tostring(...)
    local s = select(1, ...)
    for n = 2, select('#', ...) do
        s = s..","..select(n,...)
    end
    return s
end

local function memoize(f)
    local cache = {}
    return function (...)
        local al = varg_tostring(...)
        if cache[al] then
            return cache[al]
        else
            local y = f(...)
            cache[al] = y
            return y
        end
    end
end

你可能也可以使用带有__tostring的元表来做一些聪明的事情,这样参数列表就可以用tostring()转换。哦,可能性。

答案 3 :(得分:5)

在C#3.0中 - 对于递归函数,您可以执行以下操作:

public static class Helpers
{
    public static Func<A, R> Memoize<A, R>(this Func<A, Func<A,R>,  R> f)
    {
        var map = new Dictionary<A, R>();
        Func<A, R> self = null;
        self = (a) =>
        {
            R value;
            if (map.TryGetValue(a, out value))
                return value;
            value = f(a, self);
            map.Add(a, value);
            return value;
        };
        return self;
    }        
}

然后你可以创建一个这样的memoized Fibonacci函数:

var memoized_fib = Helpers.Memoize<int, int>((n,fib) => n > 1 ? fib(n - 1) + fib(n - 2) : n);
Console.WriteLine(memoized_fib(40));

答案 4 :(得分:4)

在Scala中(未经测试):

def memoize[A, B](f: (A)=>B) = {
  var cache = Map[A, B]()

  { x: A =>
    if (cache contains x) cache(x) else {
      val back = f(x)
      cache += (x -> back)

      back
    }
  }
}

请注意,这仅适用于arity 1的功能,但通过currying可以使其工作。更微妙的问题是任何函数memoize(f) != memoize(f)的{​​{1}}。解决这个问题的一个非常偷偷摸摸的方法就是如下:

f

我不认为这会编译,但它确实说明了这个想法。

答案 5 :(得分:4)

更新:评论者指出,memoization是优化递归的好方法。不可否认,我以前没有考虑过这个问题,因为我通常使用一种语言(C#),在这种语言中,广义的memoization并不是那么容易构建的。请记下下面的那篇文章。

我认为这个问题Luke likely has the most appropriate solution,但是memoization通常不是任何堆栈溢出问题的解决方案。

堆栈溢出通常是由递归比平台可以处理的更深层引起的。语言有时支持“tail recursion”,它重用当前调用的上下文,而不是为递归调用创建新的上下文。但是很多主流语言/平台都不支持这一点。例如,C#对尾递归没有固有的支持。 64位版本的.NET JITter可以将其作为IL级别的优化应用,如果您需要支持32位平台,这几乎毫无用处。

如果您的语言不支持尾递归,那么避免堆栈溢出的最佳选择是转换为显式循环(更不优雅,但有时​​是必要的),或者找到非迭代算法,例如Luke提供的这个问题。

答案 6 :(得分:3)

function memoize (f)
   local cache = {}
   return function (x)
             if cache[x] then
                return cache[x]
             else
                local y = f(x)
                cache[x] = y
                return y
             end
          end
end

triangle = memoize(triangle);

请注意,为避免堆栈溢出,仍需要三角形播种。

答案 7 :(得分:2)

这里的东西可以在不将参数转换为字符串的情况下工作。 唯一需要注意的是它无法处理nil参数。但是,已接受的解决方案无法将值nil与字符串"nil"区分开来,所以这可能没问题。

local function m(f)
  local t = { }
  local function mf(x, ...) -- memoized f
    assert(x ~= nil, 'nil passed to memoized function')
    if select('#', ...) > 0 then
      t[x] = t[x] or m(function(...) return f(x, ...) end)
      return t[x](...)
    else
      t[x] = t[x] or f(x)
      assert(t[x] ~= nil, 'memoized function returns nil')
      return t[x]
    end
  end
  return mf
end

答案 8 :(得分:2)

我已经受到这个问题的启发,在Lua中实现(又一个)灵活的memoize功能。

https://github.com/kikito/memoize.lua

主要优势:

  • 接受可变数量的参数
  • 不使用tostring;相反,它使用参数遍历树结构来组织缓存。
  • 使用返回multiple values
  • 的函数可以正常工作

将代码粘贴在此处作为参考:

local globalCache = {}

local function getFromCache(cache, args)
  local node = cache
  for i=1, #args do
    if not node.children then return {} end
    node = node.children[args[i]]
    if not node then return {} end
  end
  return node.results
end

local function insertInCache(cache, args, results)
  local arg
  local node = cache
  for i=1, #args do
    arg = args[i]
    node.children = node.children or {}
    node.children[arg] = node.children[arg] or {}
    node = node.children[arg]
  end
  node.results = results
end


-- public function

local function memoize(f)
  globalCache[f] = { results = {} }
  return function (...)
    local results = getFromCache( globalCache[f], {...} )

    if #results == 0 then
      results = { f(...) }
      insertInCache(globalCache[f], {...}, results)
    end

    return unpack(results)
  end
end

return memoize

答案 9 :(得分:1)

这是一个通用的C#3.0实现,如果有帮助的话:

public static class Memoization
{
    public static Func<T, TResult> Memoize<T, TResult>(this Func<T, TResult> function)
    {
        var cache = new Dictionary<T, TResult>();
        var nullCache = default(TResult);
        var isNullCacheSet = false;
        return  parameter =>
                {
                    TResult value;

                    if (parameter == null && isNullCacheSet)
                    {
                        return nullCache;
                    }

                    if (parameter == null)
                    {
                        nullCache = function(parameter);
                        isNullCacheSet = true;
                        return nullCache;
                    }

                    if (cache.TryGetValue(parameter, out value))
                    {
                        return value;
                    }

                    value = function(parameter);
                    cache.Add(parameter, value);
                    return value;
                };
    }
}

(引自french blog article

答案 10 :(得分:1)

在发布不同语言的备忘录的过程中,我想用非语言改变的C ++示例回复@ onebyone.livejournal.com。

首先,单个arg函数的memoizer:

template <class Result, class Arg, class ResultStore = std::map<Arg, Result> >
class memoizer1{
public:
    template <class F>
    const Result& operator()(F f, const Arg& a){
        typename ResultStore::const_iterator it = memo_.find(a);
        if(it == memo_.end()) {
            it = memo_.insert(make_pair(a, f(a))).first;
        }
        return it->second;
    }
private:
    ResultStore memo_;
};

只需创建一个memoizer实例,输入你的函数和参数。请确保不要在两个不同的函数之间共享相同的备忘录(但您可以在同一函数的不同实现之间共享它)。

接下来,一个驱动程序功能,以及一个实现。只有驱动功能需要公开     int fib(int); //司机     int fib_(int); //实施

实现:

int fib_(int n){
    ++total_ops;
    if(n == 0 || n == 1) 
        return 1;
    else
        return fib(n-1) + fib(n-2);
}

和司机,要记住

int fib(int n) {
    static memoizer1<int,int> memo;
    return memo(fib_, n);
}
在codepad.org上的

Permalink showing output。测量呼叫次数以验证正确性。 (在这里插入单元测试......)

这只会记住一个输入功能。推广多个参数或不同的参数作为读者的练习。

答案 11 :(得分:1)

在Perl中,通用记忆很容易获得。 Memoize模块是perl核心的一部分,具有高可靠性,灵活性和易用性。

该手册的示例:

# This is the documentation for Memoize 1.01
use Memoize;
memoize('slow_function');
slow_function(arguments);    # Is faster than it was before

您可以在运行时添加,删除和自定义函数的memoization!您可以为自定义纪念计算提供回调。

Memoize.pm甚至还有使memento缓存持久化的工具,因此不需要在每次调用程序时重新填充它!

以下是文档:http://perldoc.perl.org/5.8.8/Memoize.html

答案 12 :(得分:1)

有关通用Scala解决方案,请参阅this blog post,最多4个参数。

答案 13 :(得分:0)

扩展这个想法,也可以使用两个输入参数来记忆函数:

function memoize2 (f)
   local cache = {}
   return function (x, y)
             if cache[x..','..y] then
                return cache[x..','..y]
             else
                local z = f(x,y)
                cache[x..','..y] = z
                return z
             end
          end
end

请注意,参数顺序在缓存算法中很重要,因此如果参数顺序在要记忆的函数中无关紧要,那么在检查缓存之前对参数进行排序会增加获得缓存命中的几率。

但重要的是要注意某些功能无法获得有利可图的记忆。我写了 memoize2 ,看看是否可以加快寻找最大公约数的递归Euclidean algorithm

function gcd (a, b) 
   if b == 0 then return a end
   return gcd(b, a%b)
end

事实证明, gcd 对备忘录的反应不佳。它的计算远比缓存算法便宜。对于大数字,它会很快终止。过了一会儿,缓存变得非常大。这个算法可能会尽可能快。

答案 14 :(得分:0)

没有必要进行递归。第n个三角形数是n(n-1)/ 2,所以......

public int triangle(final int n){
   return n * (n - 1) / 2;
}

答案 15 :(得分:0)

请不要递言。使用x *(x + 1)/ 2公式或简单地迭代值并随时记忆。

int[] memo = new int[n+1];
int sum = 0;
for(int i = 0; i <= n; ++i)
{
  sum+=i;
  memo[i] = sum;
}
return memo[n];