Haskell动态编程的memoization工具性能差

时间:2015-05-22 05:09:44

标签: performance haskell memoization

问题从代码强制的算法问题开始,这里是链接Writing Code。 问题是一个基本的DP问题,可以在Solution Report找到解决方案。

我尝试使用Haksell解决它,我的解决方案是通过memoization递归,时间复杂度也是O(n^3)。这是我的代码:

{-# OPTIONS_GHC -O2 -funbox-strict-fields #-}
{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances, BangPatterns #-}

import Data.Functor
import Data.List

class Memo m a where
    memo :: m -> a

instance Memo m m where
    memo = id

instance Memo m a => Memo (Int -> m) [a] where
    memo f = map (memo.f) [0..]

solve :: Int -> Int -> Int -> Int -> [Int] -> Int
solve n m b mod arr = modsum.map (go n m) $ [0 .. b]
    where mods = (`rem` mod)
          modsum = foldl1' (\a b -> mods (a + b))
          go :: Int -> Int -> Int -> Int
          go _ 0 0 = 1
          go _ 0 _ = 0
          go 0 _ _ = 0
          go i j k = mods $ a + b
            where !id = i - 1
                  !ai = arr !! id
                  !a = (if k >= ai then memo_go !! i !! (j - 1) !! (k - ai) else 0)
                  !b = memo_go !! id !! j !! k
          memo_go = memo go


main = do [n, m, b, mod] <- map read.words <$> getLine
          arr <- map read.words <$> getLine
          print $ solve n m b mod arr          

使用RTS的该程序的效果报告是:

    Fri May 22 12:25 2015 Time and Allocation Profiling Report  (Final)

       test.exe +RTS -p -RTS

    total time  =        5.75 secs   (5751 ticks @ 1000 us, 1 processor)
    total alloc = 189,957,112 bytes  (excludes profiling overheads)

COST CENTRE MODULE  %time %alloc

solve.go    Main     97.2    8.5
memo        Main      2.3   82.8
solve.mods  Main      0.5    8.5


                                                                  individual     inherited
COST CENTRE          MODULE                     no.     entries  %time %alloc   %time %alloc

MAIN                 MAIN                        52           0    0.0    0.0   100.0  100.0
 main                Main                       105           0    0.0    0.3   100.0  100.0
  solve.modsum       Main                       108           1    0.0    0.0     0.0    0.0
  solve              Main                       106           1    0.0    0.0   100.0   99.7
   solve.memo_go     Main                       114           1    0.0    0.0    99.9   99.7
    memo             Main                       115       10301    2.3   82.8    99.9   99.7
     solve.go        Main                       118     1024999   97.2    8.5    97.6   16.9
      solve.mods     Main                       119           0    0.5    8.5     0.5    8.5
   solve.go          Main                       113         101    0.0    0.0     0.0    0.0
    solve.mods       Main                       120           0    0.0    0.0     0.0    0.0
   solve.mods        Main                       111           1    0.0    0.0     0.0    0.0
   solve.modsum      Main                       109           0    0.0    0.0     0.0    0.0
    solve.modsum.\   Main                       110         100    0.0    0.0     0.0    0.0
     solve.mods      Main                       112           0    0.0    0.0     0.0    0.0
 CAF                 GHC.IO.Encoding.CodePage    91           0    0.0    0.0     0.0    0.0
 CAF                 GHC.IO.Encoding             84           0    0.0    0.0     0.0    0.0
 CAF                 Text.Read.Lex               81           0    0.0    0.0     0.0    0.0
 CAF                 GHC.IO.Handle.FD            72           0    0.0    0.0     0.0    0.0
 CAF:main1           Main                        68           0    0.0    0.0     0.0    0.0
  main               Main                       104           1    0.0    0.0     0.0    0.0
 CAF:lvl2_r3ul       Main                        64           0    0.0    0.0     0.0    0.0
 CAF:$fMemo(->)[]1   Main                        61           0    0.0    0.0     0.0    0.0
  memo               Main                       116           0    0.0    0.0     0.0    0.0
 CAF:$fMemomm_$cmemo Main                        59           0    0.0    0.0     0.0    0.0
  memo               Main                       117           1    0.0    0.0     0.0    0.0

可以预见,瓶颈是memo_go。 memo_go是一个三维数组,其中每个元素都是thunk go i j k,它只会被评估一次。但是,这段代码得到了法官的TLE。

我徘徊如何提高此代码的性能?

Updata公司 我将List更改为Array并获得了一个MLE,所以我使用滚动数组来减少内存,这是我的代码:

{-# OPTIONS_GHC -O2 -funbox-strict-fields #-}
{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances, BangPatterns #-}

import Data.Functor
import Data.List (foldl', foldl1')
import Data.Array
--import qualified  Data.Vector as V

size = 510

class Memo m a where
    memo :: m -> a

instance Memo m m where
    memo = id

instance Memo m a => Memo (Int -> m) (Array Int a) where
    memo f = fmap (memo.f) $ listArray (0, size) [0..size]

solve :: Int -> Int -> Int -> Int -> [Int] -> Int
solve n m b mod lis = modsum.map (res ! m !) $ [0 .. b]
    where arr = listArray (0, n - 1) lis
          mods = (`rem` mod)
          modsum = foldl1' (\a b -> mods (a + b))
          init :: Int -> Int -> Int
          init 0 0 = 1
          init _ _ = 0
          arr_init = memo init
          dp n = foldl' build arr_init [1..n]
            where build ar idx = memo_go'
                    where !ai = arr ! (idx - 1)
                          memo_go' = memo go'
                          go' i j = mods $ a + b
                            where !a = ar ! i ! j
                                  !b = if i > 0 && j >= ai then memo_go' ! (i - 1) ! (j - ai) else 0
          res = dp n

main = do [n, m, b, mod] <- map read.words <$> getLine
          arr <- map read.words <$> getLine
          print $ solve n m b mod arr          

仍然是MLE,也许我应该使用可变数据结构:(

0 个答案:

没有答案