在Haskell

时间:2016-09-08 15:33:15

标签: haskell recursion

我正在解决this programming problem,而且我超出了当前解决方案的时间限制。我相信我的问题的解决方案是memoization。但是,我不明白记录的备忘录解决方案here

这是我当前解决方案中的主要功能。

maxCuts :: Int -> Int -> Int -> Int -> Int
maxCuts n a b c  
    | n == 0    = 0
    | n < 0     = -10000
    | otherwise = max (max amax bmax) cmax
    where 
        amax = 1 + maxCuts (n - a) a b c
        bmax = 1 + maxCuts (n - b) a b c
        cmax = 1 + maxCuts (n - c) a b c

如果b和c相对于n较小,则此函数运行时间过长。我只想复制他们用于阶乘函数的解决方案,但该函数只需要一个参数。我有四个参数,但我只想在第一个参数n上键入memoiziation。请注意,a bc在递归调用中不会更改。

2 个答案:

答案 0 :(得分:2)

重写你的函数定义:

 maxCuts :: Int -> Int -> Int -> Int -> Int
 maxCuts n a b c = maxCuts' n where
     maxCuts' n
          | n == 0    = 0
          | n < 0     = -10000
          | otherwise = max (max amax bmax) cmax
            where 
               amax = 1 + maxCuts' (n - a)  
               bmax = 1 + maxCuts' (n - b) 
               cmax = 1 + maxCuts' (n - c)

现在你有一个可以记忆的单参数函数。

答案 1 :(得分:0)

除此之外:你的算法不只是计算类似div n (minimum [a,b,c])的东西吗?

正如您所指出的,参数a,b和c不会改变,因此首先重写函数以将参数n放在最后。

如果您决定使用列表来记忆它所需的功能值 稍微注意确保GHC将保存映射列表:

import Debug.Trace

maxCuts' :: Int -> Int -> Int -> Int -> Int
maxCuts' a b c n = memoized_go n
  where
    memoized_go n
      | n < 0 = -10000
      | otherwise =  mapped_list !! n

    mapped_list = map go [0..]

    go n | trace msg False = undefined
      where msg = "go called for " ++ show n
    go 0 = 0
    go n = maximum [amax, bmax, cmax]
      where
        amax = 1 + memoized_go (n-a)
        bmax = 1 + memoized_go (n-b)
        cmax = 1 + memoized_go (n-c)

test1 = print $ maxCuts' 1 2 3 10

请注意定义的循环依赖关系:memoized_go取决于mapped_listgo取决于memozied_go,取决于n < 0

由于列表只允许非负索引,因此必须使用trace个案例 以单独的警卫模式处理。

go次来电显示n每{... 1}}值只调用一次mapped_list。 例如,考虑尝试在不定义maxCuts2 :: Int -> Int -> Int -> Int -> Int maxCuts2 a b c n = memoized_go n where memoized_go n | n < 0 = -10000 | otherwise = (map go [0..]) !! n -- mapped_list = map go [0..] go n | trace msg False = undefined where msg = "go called for " ++ show n go 0 = 0 go n = maximum [amax, bmax, cmax] where amax = 1 + memoized_go (n-a) bmax = 1 + memoized_go (n-b) cmax = 1 + memoized_go (n-c) test2 = print $ maxCuts2 1 2 3 11 的情况下执行此操作:

test2

正在运行go表示n多次被调用amax的相同值。

<强>更新

为了避免创建大量未评估的thunk,我会使用BangPatterns 对于bmaxcmax{-# LANGUAGE BangPatterns #-} maxCuts' ... = ... where !amax = 1 + ... !bmax = 1 + ... !cmax = 1 + ...

def test_action(self):
    sc = SuperCool()
    assert sc.action(1) == 1