我正在研究项目Euler #14,并且有一个解决方案来获得答案,但是当我尝试运行代码时,我遇到了堆栈空间溢出错误。该算法在交互式GHCI中工作正常(在低数字上),但是当我向它抛出一个非常大的数字并尝试编译它时,它不会工作。
以下是交互式ghci中它的作用的粗略概念。在我的电脑上计算“回答50000”大约需要10秒钟。
编辑:让GHCI运行问题几分钟后,它会吐出正确答案。
*Euler System.IO> answer 1000000
(525,837799)
但是,当编译程序本机运行时,这并没有解决堆栈溢出错误。
*Euler System.IO> answer 10
(20,9)
*Euler System.IO> answer 100
(119,97)
*Euler System.IO> answer 1000
(179,871)
*Euler System.IO> answer 10000
(262,6171)
*Euler System.IO> answer 50000
(324,35655)
我该怎么做才能得到“回答1000000”的答案?我想我的算法需要稍微调整一下,但我不知道如何去做。
这是我到目前为止编写的代码:
module Main
where
import System.IO
import Control.Monad
main = print (answer 1000000)
-- Count the length of the sequences
-- count' creates a tuple with the second value
-- being the starting number of the game
-- and the first value being the total
-- length of the chain
count' n = (cSeq n, n)
cSeq n = length $ game n
-- Find the maximum chain value of the game
answer n = maximum $ map count' [1..n]
-- Working game.
-- game 13 = [13,40,20,10,5,16,8,4,2,1]
game n = n : play n
play x
| x <= 0 = [] -- is negative or 0
| x == 1 = [] -- is 1
| even x = doEven x : play ((doEven x)) -- even
| otherwise = doOdd x : play ((doOdd x)) -- odd
where doOdd x = (3 * x) + 1
doEven x = (x `div` 2)
答案 0 :(得分:4)
这里的问题是maximum
太懒了。它不是跟踪最大的元素,而是构建了一个巨大的max
thunk树。这是因为maximum
是根据foldl
定义的,因此评估如下:
maximum [1, 2, 3, 4, 5]
foldl max 1 [2, 3, 4, 5]
foldl max (max 1 2) [3, 4, 5]
foldl max (max (max 1 2) 3) [4, 5]
foldl max (max (max (max 1 2) 3) 4) [5]
foldl max (max (max (max (max 1 2) 3) 4) 5) []
max (max (max (max 1 2) 3) 4) 5 -- this expression will be huge for large lists
尝试评估太多这些嵌套的max
调用会导致堆栈溢出。
解决方案是强制它使用严格版本foldl'
(或者,在本例中为其表兄foldl1'
)进行评估。这可以防止max
在每个步骤中减少它们来构建:
foldl1' max [1, 2, 3, 4, 5]
foldl' max 1 [2, 3, 4, 5]
foldl' max 2 [3, 4, 5]
foldl' max 3 [4, 5]
foldl' max 4 [5]
foldl' max 5 []
5
如果你使用-O2
进行编译,GHC通常可以自行解决这些问题,其中({等等}}对你的程序进行严格的分析。但是,我认为编写不需要依赖优化工作的程序是一种很好的做法。
注意:修复此问题后,生成的程序仍然很慢。您可能希望使用memoization来解决此问题。
答案 1 :(得分:4)
@hammar已经pointed out maximum
太懒了,以及如何解决这个问题(使用foldl1'
严格版foldl1
)。
但代码中存在进一步的低效率。
cSeq n = length $ game n
cSeq
让game
构建一个列表,只计算其长度。不幸的是,length
不是“好消费者”,因此中间列表的构造不会被融合。这是相当多的不必要的分配和成本时间。消除这些列表
cSeq n = coll (1 :: Int) n
where
coll acc 1 = acc
coll acc m
| even m = coll (acc + 1) (m `div` 2)
| otherwise = coll (acc + 1) (3*m+1)
将分配减少了65%,运行时间减少了约20%(仍然很慢)。接下来,您正在使用div
,除正常分割外,还会执行签名检查。由于所涉及的所有数字都是正数,因此使用quot
会加快速度(在这里不多,但稍后会变得很重要)。
接下来的重点是,因为你没有给出类型签名,数字的类型(除了使用length
或表达式类型签名(1 :: Int)
确定的位置)在我的重写中)是Integer
。 Integer
上的操作比Int
上的相应操作慢得多,因此如果可能,您应该使用Int
(或Word
)而不是Integer
时速度很重要。如果您有64位GHC,Int
就足以进行这些计算,那么使用div
时运行时间减少一半,使用quot
时运行时间减少约70%本机代码生成器,使用LLVM后端时,使用div
时运行时间减少约70%,使用quot
时运行时间减少约95%。
本机代码生成器和LLVM后端之间的区别主要是由于一些基本的低级优化。
even
和odd
已定义
even, odd :: (Integral a) => a -> Bool
even n = n `rem` 2 == 0
odd = not . even
GHC.Real
中的。当类型为Int
时,LLVM知道将除法替换为2用于通过按位和(n .&. 1 == 0
)确定模数。本机代码生成器(尚未)执行许多这些低级优化。如果您手动执行此操作,NCG和LLVM后端生成的代码几乎完全相同。
当使用div
时,NCG和LLVM都无法用短移位和加法序列替换除法,因此使用符号测试得到相对较慢的机器除法指令。使用quot
,两者都可以为Int
执行此操作,因此您可以获得更快的代码。
所有出现的数字都是正数的知识允许我们用简单的右移替换除2,而没有任何代码来纠正否定参数,这使得LLVM后端产生的代码加速了另外~33%,奇怪的是,它对NCG没有任何影响。
所以从最初的8秒加/减(使用NCG少一点,使用LLVM后端稍微多一点),我们已经去了
module Main (main)
where
import Data.List
import Data.Bits
main = print (answer (1000000 :: Int))
-- Count the length of the sequences
-- count' creates a tuple with the second value
-- being the starting number of the game
-- and the first value being the total
-- length of the chain
count' n = (cSeq n, n)
cSeq n = go (1 :: Int) n
where
go !acc 1 = acc
go acc m
| even' m = go (acc+1) (m `shiftR` 1)
| otherwise = go (acc+1) (3*m+1)
even' :: Int -> Bool
even' m = m .&. 1 == 0
-- Find the maximum chain value of the game
answer n = foldl1' max $ map count' [1..n]
使用NCG需要0.37秒,在我的设置上使用LLVM后端需要0.27秒。
运行时间略有改善,但通过手动递归替换foldl1' max
可以大大减少分配,
answer n = go 1 1 2
where
go ml mi i
| n < i = (ml,mi)
| l > ml = go l i (i+1)
| otherwise = go ml mi (i+1)
where
l = cSeq i
这使得它分别为0.35。 0.25秒(并产生一个微小的52,936 bytes allocated in the heap
)。
现在,如果仍然太慢,你可以担心一个好的回忆策略。我所知道的最好的(1)是使用一个未装箱的数组来存储不超过限制的数字的链长,
{-# LANGUAGE BangPatterns #-}
module Main (main) where
import System.Environment (getArgs)
import Data.Array.ST
import Data.Array.Base
import Control.Monad.ST
import Data.Bits
main :: IO ()
main = do
args <- getArgs
let bd = case args of
a:_ -> read a
_ -> 100000
print $ mxColl bd
mxColl :: Int -> (Int,Int)
mxColl bd = runST $ do
arr <- newArray (0,bd) 0
unsafeWrite arr 1 1
goColl arr bd 1 1 2
goColl :: STUArray s Int Int -> Int -> Int -> Int -> Int -> ST s (Int,Int)
goColl arr bd ms ml i
| bd < i = return (ms,ml)
| otherwise = do
nln <- collatzLength arr bd i
if ml < nln
then goColl arr bd i nln (i+1)
else goColl arr bd ms ml (i+1)
collatzLength :: STUArray s Int Int -> Int -> Int -> ST s Int
collatzLength arr bd n = go 1 n
where
go !l 1 = return l
go l m
| bd < m = go (l+1) $ case m .&. 1 of
0 -> m `shiftR` 1
_ -> 3*m+1
| otherwise = do
l' <- unsafeRead arr m
case l' of
0 -> do
l'' <- go 1 $ case m .&. 1 of
0 -> m `shiftR` 1
_ -> 3*m+1
unsafeWrite arr m (l''+1)
return (l + l'')
_ -> return (l+l'-1)
使用NCG编译时,0.04秒内的作业限制为1000000,使用LLVM后端编程为0.05(显然,这与NCG优化STUArray
代码不太一样)。
如果您没有64位GHC,则不能简单地使用Int
,因为对于某些输入,这会溢出。
但绝大部分的计算仍然在Int
范围内执行,因此您应该尽可能使用它,并且只在必要时移至Integer
。
switch :: Int
switch = (maxBound - 1) `quot` 3
back :: Integer
back = 2 * fromIntegral (maxBound :: Int)
cSeq :: Int -> Int
cSeq n = goInt 1 n
where
goInt acc 1 = acc
goInt acc m
| m .&. 1 == 0 = goInt (acc+1) (m `shiftR` 1)
| m > switch = goInteger (acc+1) (3*toInteger m + 1)
| otherwise = goInt (acc+1) (3*m+1)
goInteger acc m
| fromInteger m .&. (1 :: Int) == 1 = goInteger (acc+1) (3*m+1)
| m > back = goInteger (acc+1) (m `quot` 2) -- yup, quot is faster than shift for Integer here
| otherwise = goInt (acc + 1) (fromInteger $ m `quot` 2)
使得优化循环更加困难,因此它比使用Int
的单循环慢,但仍然不错。这里(永远不会运行Integer
循环),NCG需要0.42秒,LLVM后端需要0.37秒(这与使用纯quot
中的Int
几乎相同版本)。
对于备忘版本使用类似的技巧也会产生类似的后果,它比纯Int
版本慢得多,但与未版本的版本相比仍然非常快。
(1)对于这种特殊(类型)问题,您需要记住连续范围的参数的结果。对于其他问题,Map
或其他一些数据结构将是更好的选择。
答案 2 :(得分:1)
似乎maximum
函数是已经指出的罪魁祸首,但如果使用-O2
标志编译程序,则不必担心它。
该程序仍然很慢,这是因为该问题应该教你记忆。执行此操作的一个好方法是使用Data.Memocombinators
:
import Data.MemoCombinators
import Control.Arrow
import Data.List
import Data.Ord
import System.Environment
play m = maximumBy (comparing snd) . map (second threeNPuzzle) $ zip [1..] [1..m]
where
threeNPuzzle = arrayRange (1,m) memoized
memoized n
| n == 1 = 1
| odd n = 1 + threeNPuzzle (3*n + 1)
| even n = 1 + threeNPuzzle (n `div` 2)
main = getArgs >>= print . play . read . head
在我的机器上使用-O2
编译时,上述程序会在一秒钟内运行。
请注意,在这种情况下,记住threeNPuzzle找到的所有值并不是一个好主意,上面的程序会记住这些值直到极限(问题中为1000000)。