运行recusive编译代码时堆栈空间溢出错误。微调算法,还是提供更多资源?

时间:2012-10-05 07:26:30

标签: haskell

我正在研究项目Euler #14,并且有一个解决方案来获得答案,但是当我尝试运行代码时,我遇到了堆栈空间溢出错误。该算法在交互式GHCI中工作正常(在低数字上),但是当我向它抛出一个非常大的数字并尝试编译它时,它不会工作。

以下是交互式ghci中它的作用的粗略概念。在我的电脑上计算“回答50000”大约需要10秒钟。

编辑:让GHCI运行问题几分钟后,它会吐出正确答案。

*Euler System.IO> answer 1000000
    (525,837799)

但是,当编译程序本机运行时,这并没有解决堆栈溢出错误。

*Euler System.IO> answer 10
    (20,9)
*Euler System.IO> answer 100
    (119,97)
*Euler System.IO> answer 1000
    (179,871)
*Euler System.IO> answer 10000
    (262,6171)
*Euler System.IO> answer 50000
    (324,35655)

我该怎么做才能得到“回答1000000”的答案?我想我的算法需要稍微调整一下,但我不知道如何去做。

这是我到目前为止编写的代码:

module Main
    where

import System.IO
import Control.Monad

main = print (answer 1000000)

-- Count the length of the sequences
-- count' creates a tuple with the second value
-- being the starting number of the game
-- and the first value being the total 
-- length of the chain
count' n = (cSeq n, n)
cSeq n = length $ game n

-- Find the maximum chain value of the game
answer n = maximum $ map count' [1..n]

-- Working game. 
-- game 13 = [13,40,20,10,5,16,8,4,2,1]
game n = n : play n
play x
    | x <= 0 = []                               -- is negative or 0
    | x == 1 = []                               -- is 1
    | even x = doEven x : play ((doEven x))     -- even
    | otherwise = doOdd x : play ((doOdd x))    -- odd
  where doOdd x = (3 * x) + 1
        doEven  x = (x `div` 2)

3 个答案:

答案 0 :(得分:4)

这里的问题是maximum太懒了。它不是跟踪最大的元素,而是构建了一个巨大的max thunk树。这是因为maximum是根据foldl定义的,因此评估如下:

maximum [1, 2, 3, 4, 5]
foldl max 1 [2, 3, 4, 5]
foldl max (max 1 2) [3, 4, 5]
foldl max (max (max 1 2) 3) [4, 5]
foldl max (max (max (max 1 2) 3) 4) [5]
foldl max (max (max (max (max 1 2) 3) 4) 5) []
max (max (max (max 1 2) 3) 4) 5  -- this expression will be huge for large lists

尝试评估太多这些嵌套的max调用会导致堆栈溢出。

解决方案是强制它使用严格版本foldl'(或者,在本例中为其表兄foldl1')进行评估。这可以防止max在每个步骤中减少它们来构建:

foldl1' max [1, 2, 3, 4, 5]
foldl' max 1 [2, 3, 4, 5]
foldl' max 2 [3, 4, 5]
foldl' max 3 [4, 5]
foldl' max 4 [5]
foldl' max 5 []
5

如果你使用-O2进行编译,GHC通常可以自行解决这些问题,其中({等等}}对你的程序进行严格的分析。但是,我认为编写不需要依赖优化工作的程序是一种很好的做法。

注意:修复此问题后,生成的程序仍然很慢。您可能希望使用memoization来解决此问题。

答案 1 :(得分:4)

@hammar已经pointed out maximum太懒了,以及如何解决这个问题(使用foldl1'严格版foldl1)。

但代码中存在进一步的低效率。

cSeq n = length $ game n

cSeqgame构建一个列表,只计算其长度。不幸的是,length不是“好消费者”,因此中间列表的构造不会被融合。这是相当多的不必要的分配和成本时间。消除这些列表

cSeq n = coll (1 :: Int) n
  where
    coll acc 1 = acc
    coll acc m
      | even m    = coll (acc + 1) (m `div` 2)
      | otherwise = coll (acc + 1) (3*m+1)

将分配减少了65%,运行时间减少了约20%(仍然很慢)。接下来,您正在使用div,除正常分割外,还会执行签名检查。由于所涉及的所有数字都是正数,因此使用quot会加快速度(在这里不多,但稍后会变得很重要)。

接下来的重点是,因为你没有给出类型签名,数字的类型(除了使用length或表达式类型签名(1 :: Int)确定的位置)在我的重写中)是IntegerInteger上的操作比Int上的相应操作慢得多,因此如果可能,您应该使用Int(或Word)而不是Integer时速度很重要。如果您有64位GHC,Int就足以进行这些计算,那么使用div时运行时间减少一半,使用quot时运行时间减少约70%本机代码生成器,使用LLVM后端时,使用div时运行时间减少约70%,使用quot时运行时间减少约95%。

本机代码生成器和LLVM后端之间的区别主要是由于一些基本的低级优化。

evenodd已定义

even, odd       :: (Integral a) => a -> Bool
even n          =  n `rem` 2 == 0
odd             =  not . even
GHC.Real中的

。当类型为Int时,LLVM知道将除法替换为2用于通过按位和(n .&. 1 == 0)确定模数。本机代码生成器(尚未)执行许多这些低级优化。如果您手动执行此操作,NCG和LLVM后端生成的代码几乎完全相同。

当使用div时,NCG和LLVM都无法用短移位和加法序列替换除法,因此使用符号测试得到相对较慢的机器除法指令。使用quot,两者都可以为Int执行此操作,因此您可以获得更快的代码。

所有出现的数字都是正数的知识允许我们用简单的右移替换除2,而没有任何代码来纠正否定参数,这使得LLVM后端产生的代码加速了另外~33%,奇怪的是,它对NCG没有任何影响。

所以从最初的8秒加/减(使用NCG少一点,使用LLVM后端稍微多一点),我们已经去了

module Main (main)
    where

import Data.List
import Data.Bits

main = print (answer (1000000 :: Int))

-- Count the length of the sequences
-- count' creates a tuple with the second value
-- being the starting number of the game
-- and the first value being the total 
-- length of the chain
count' n = (cSeq n, n)
cSeq n = go (1 :: Int) n
  where
    go !acc 1 = acc
    go acc m
        | even' m   = go (acc+1) (m `shiftR` 1)
        | otherwise = go (acc+1) (3*m+1)

even' :: Int -> Bool
even' m = m .&. 1 == 0

-- Find the maximum chain value of the game
answer n = foldl1' max $ map count' [1..n]

使用NCG需要0.37秒,在我的设置上使用LLVM后端需要0.27秒。

运行时间略有改善,但通过手动递归替换foldl1' max可以大大减少分配,

answer n = go 1 1 2
  where
    go ml mi i
        | n < i     = (ml,mi)
        | l > ml    = go l i (i+1)
        | otherwise = go ml mi (i+1)
          where
            l = cSeq i

这使得它分别为0.35。 0.25秒(并产生一个微小的52,936 bytes allocated in the heap)。

现在,如果仍然太慢,你可以担心一个好的回忆策略。我所知道的最好的(1)是使用一个未装箱的数组来存储不超过限制的数字的链长,

{-# LANGUAGE BangPatterns #-}
module Main (main) where

import System.Environment (getArgs)
import Data.Array.ST
import Data.Array.Base
import Control.Monad.ST
import Data.Bits

main :: IO ()
main = do
    args <- getArgs
    let bd = case args of
               a:_ -> read a
               _   -> 100000
    print $ mxColl bd

mxColl :: Int -> (Int,Int)
mxColl bd = runST $ do
    arr <- newArray (0,bd) 0
    unsafeWrite arr 1 1
    goColl arr bd 1 1 2

goColl :: STUArray s Int Int -> Int -> Int -> Int -> Int -> ST s (Int,Int)
goColl arr bd ms ml i
    | bd < i    = return (ms,ml)
    | otherwise = do
        nln <- collatzLength arr bd i
        if ml < nln
          then goColl arr bd i nln (i+1)
          else goColl arr bd ms ml (i+1)

collatzLength :: STUArray s Int Int -> Int -> Int -> ST s Int
collatzLength arr bd n = go 1 n
  where
    go !l 1 = return l
    go l m
        | bd < m    = go (l+1) $ case m .&. 1 of
                                   0 -> m `shiftR` 1
                                   _ -> 3*m+1
        | otherwise = do
            l' <- unsafeRead arr m
            case l' of
              0 -> do
                  l'' <- go 1 $ case m .&. 1 of
                                  0 -> m `shiftR` 1
                                  _ -> 3*m+1
                  unsafeWrite arr m (l''+1)
                  return (l + l'')
              _ -> return (l+l'-1)

使用NCG编译时,0.04秒内的作业限制为1000000,使用LLVM后端编程为0.05(显然,这与NCG优化STUArray代码不太一样)。

如果您没有64位GHC,则不能简单地使用Int,因为对于某些输入,这会溢出。 但绝大部分的计算仍然在Int范围内执行,因此您应该尽可能使用它,并且只在必要时移至Integer

switch :: Int
switch = (maxBound - 1) `quot` 3

back :: Integer
back = 2 * fromIntegral (maxBound :: Int)

cSeq :: Int -> Int
cSeq n = goInt 1 n
  where
    goInt acc 1      = acc
    goInt acc m
      | m .&. 1 == 0 = goInt (acc+1) (m `shiftR` 1)
      | m > switch   = goInteger (acc+1) (3*toInteger m + 1)
      | otherwise    = goInt (acc+1) (3*m+1)
    goInteger acc m
      | fromInteger m .&. (1 :: Int) == 1 = goInteger (acc+1) (3*m+1)
      | m > back  = goInteger (acc+1) (m `quot` 2)  -- yup, quot is faster than shift for Integer here
      | otherwise = goInt (acc + 1) (fromInteger $ m `quot` 2)

使得优化循环更加困难,因此它比使用Int的单循环慢,但仍然不错。这里(永远不会运行Integer循环),NCG需要0.42秒,LLVM后端需要0.37秒(这与使用纯quot中的Int几乎相同版本)。

对于备忘版本使用类似的技巧也会产生类似的后果,它比纯Int版本慢得多,但与未版本的版本相比仍然非常快。


(1)对于这种特殊(类型)问题,您需要记住连续范围的参数的结果。对于其他问题,Map或其他一些数据结构将是更好的选择。

答案 2 :(得分:1)

似乎maximum函数是已经指出的罪魁祸首,但如果使用-O2标志编译程序,则不必担心它。

该程序仍然很慢,这是因为该问题应该教你记忆。执行此操作的一个好方法是使用Data.Memocombinators

进行攻击
import Data.MemoCombinators
import Control.Arrow
import Data.List
import Data.Ord
import System.Environment

play m = maximumBy (comparing snd) . map (second threeNPuzzle) $ zip [1..] [1..m]
  where
    threeNPuzzle = arrayRange (1,m) memoized
    memoized n 
      | n == 1 = 1
      | odd n  = 1 + threeNPuzzle (3*n + 1)
      | even n = 1 + threeNPuzzle (n `div` 2)

main = getArgs >>= print . play . read . head

在我的机器上使用-O2编译时,上述程序会在一秒钟内运行。

请注意,在这种情况下,记住threeNPuzzle找到的所有值并不是一个好主意,上面的程序会记住这些值直到极限(问题中为1000000)。