Haskell:代码运行速度太慢

时间:2016-09-20 18:27:29

标签: performance haskell recursion motzkin-numbers

我有一个代码,用于计算Motzkin编号:

module Main where

    -- Program execution begins here
    main :: IO ()
    main = interact (unlines . (map show) . map wave . (map read) . words)

    -- Compute Motzkin number
    wave :: Integer -> Integer
    wave 0 = 1
    wave 1 = 1
    wave n = ((3 * n - 3) * wave (n - 2) + (2 * n + 1) * wave (n - 1)) `div` (n + 2)

但即使是简单数字30的输出也需要一段时间才能返回。

任何优化想法??

4 个答案:

答案 0 :(得分:10)

计算Fibonacci数字有一个标准技巧,可以很容易地适应你的问题。 Fibonacci数字的天真定义是:

fibFunction :: Int -> Integer
fibFunction 0 = 1
fibFunction 1 = 1
fibFunction n = fibFunction (n-2) + fibFunction (n-1)

然而,这是非常昂贵的:因为递归的所有叶子都是1,如果是fib x = y,那么我们必须执行y递归调用!由于斐波纳契数以指数方式增长,这是一个糟糕的事态。但是通过动态编程,我们可以共享两个递归调用所需的计算。令人愉悦的单行内容如下:

fibList :: [Integer]
fibList = 1 : 1 : zipWith (+) fibList (tail fibList)

起初看起来有点令人费解;这里fibList的{​​{1}}参数作为前两个索引的递归,而zipWith参数作为前一个索引的递归,这给了我们tail fibListfib (n-2)值。开头的两个fib (n-1)当然是基本情况。有other good questions here on SO更详细地解释了这种技术,你应该研究这些代码和那些答案,直到你感觉它是如何工作的以及为什么它非常快。

如有必要,可以使用1从中恢复Int -> Integer类型签名。

让我们尝试将此技术应用于您的功能。与计算Fibonacci数一样,您需要前一个和倒数第二个值;另外还需要当前的指数。这可以通过在(!!)的调用中加入[2..]来完成。以下是它的外观:

zipWith

和以前一样,可以使用waves :: [Integer] waves = 1 : 1 : zipWith3 thisWave [2..] waves (tail waves) where thisWave n back2 back1 = ((3 * n - 3) * back2 + (2 * n + 1) * back1) `div` (n + 2) (!!)恢复功能版本(如果真的需要genericIndex个索引)。我们可以确认它在ghci中计算相同的函数(但更快,并且使用更少的内存):

Integer

答案 1 :(得分:6)

当n = 30时,您需要计算wave 29wave 28,而这需要计算wave 28wave 27两次和wave 26等等,这很快就会达到数十亿。

您可以使用与计算斐波那契数字相同的技巧:

wave 0 = 1
wave 1 = 1
wave n = helper 1 1 2
    where
       helper x y k | k <n      = helper y z (k+1)
                    | otherwise = z
                    where z = ((3*k-3) * x + (2*k+1) * y) `div` (k+2)

这是以线性时间运行的,并且帮助器为kwave (k-2)的每个wave (k-1)值准备好了。

答案 2 :(得分:1)

这是一个记忆版

wave = ((1:1:map waveCalc [2..]) !!)
    where waveCalc n = ( (2*n+1)*wave (n-1) + (3*n-3)*wave (n-2) ) `div` (n+2)

答案 3 :(得分:0)

感谢大家的回复。根据我对Memoization的理解,我将代码重写为:

mwave :: Int -> Int
mwave = (map wave [0..] !!)
  where wave 0 = 1
        wave 1 = 1
        wave n = ((3 * n - 3) * mwave (n - 2) + (2 * n + 1) * mwave (n - 1)) `div` (n + 2)

digits :: Int -> Int
digits n = (mwave n) `mod` 10^(100::Int)

有关如何输出模块10 ^ 100的答案的想法