我正在做一些Project Euler项目(不是作为功课,只是为了娱乐/学习),而我正在学习Haskell。其中一个问题是找到最大的Collatz序列,起始编号低于100万(http://projecteuler.net/problem=14)
所以,无论如何,我能够做到这一点,我的算法在编译时能够正常运行并获得正确的答案。但是,它使用1000000深度递归。
所以我的问题是:我做对了吗?原样,Haskell的正确方法是什么?我怎么能让它更快?另外,在内存使用情况下,如何在低级别实际实现递归?如何使用记忆?
( SPOILER ALERT:如果您想在不查看答案的情况下自行解决Project Euler问题#14,请不要看这个。)
- haskell脚本 --problem:找到一个不到200万的最长的collatz链。
collatzLength x| x == 1 = 1
| otherwise = 1 + collatzLength(nextStep x)
longestChain (num, numLength) bound counter
| counter >= bound = (num, numLength)
| otherwise = longestChain (longerOf (num,numLength)
(counter, (collatzLength counter)) ) bound (counter + 1)
--I know this is a messy function, but I was doing this problem just
--for myself, so I didn't bother making some utility functions for it.
--also, I split the big line in half to display on here nicer, would
--it actually run with this line split?
longerOf (a1,a2) (b1,b2)| a2 > b2 = (a1,a2)
| otherwise = (b1,b2)
nextStep n | mod n 2 == 0 = (n `div` 2)
| otherwise = 3*n + 1
main = print (longestChain (0,0) 1000000 1)
当使用-O2编译时,程序运行大约7.5秒。
那么,有什么建议/意见吗?我想尝试让程序运行得更快,内存使用量更少,我希望用非常Haskellian(应该是一个单词)的方式来实现。
提前致谢!
答案 0 :(得分:7)
编辑以回答问题
我做对了吗?
几乎,正如评论所说,你构建了一个很大的1+(1+(1+...))
- 使用严格的累加器或者更高级的函数为你处理事情。还有其他一些小问题,比如定义一个函数来比较第二个元素而不是使用maximumBy (comparing snd)
,但这更具风格。
原样,Haskell的正确方法是什么?
这是可以接受的惯用Haskell代码。
我怎样才能让它更快?
请参阅下面的基准测试。欧拉表现问题的极为常见的答案是:
rem
代替mod
。对于您的情况,了解或发现div
倾向于编译为慢于quot
的内容也很有用。另外,在内存使用情况下,如何在低级别实际实现递归? 如何使用记忆?
这两个问题都非常广泛。完整的答案可能需要解决延迟评估,尾部调用优化,工作人员转换,垃圾收集等问题。我建议您随着时间的推移更深入地探索这些答案(或希望有人在这里做出我正在避免的完整答案)。
原帖 - 基准数字
原件:
$ ghc -O2 so.hs ; time ./so
[1 of 1] Compiling Main ( so.hs, so.o )
Linking so ...
(837799,525)
real 0m5.971s
user 0m5.940s
sys 0m0.019s
使用带有累加器的worker函数collatzLength
:
$ ghc -O2 so.hs ; time ./so
[1 of 1] Compiling Main ( so.hs, so.o )
Linking so ...
(837799,525)
real 0m5.617s
user 0m5.590s
sys 0m0.012s
使用Int
而非默认为Integer
- 使用类型签名也更容易阅读!
$ ghc -O2 so.hs ; time ./so
[1 of 1] Compiling Main ( so.hs, so.o )
Linking so ...
(837799,525)
real 0m2.937s
user 0m2.932s
sys 0m0.001s
使用rem
而非mod
:
$ ghc -O2 so.hs ; time ./so
[1 of 1] Compiling Main ( so.hs, so.o )
Linking so ...
(837799,525)
real 0m2.436s
user 0m2.431s
sys 0m0.001s
使用quotRem
而非rem
然后使用div
:
$ ghc -O2 so.hs ; time ./so
[1 of 1] Compiling Main ( so.hs, so.o )
Linking so ...
(837799,525)
real 0m1.672s
user 0m1.669s
sys 0m0.002s
这与上一个问题非常相似:Speed comparison with Project Euler: C vs Python vs Erlang vs Haskell
编辑:是的,正如Daniel Fischer建议的那样,使用.&.
和shiftR
的位操作会改进quotRem
:
$ ghc -O2 so.hs ; time ./so
(837799,525)
real 0m0.314s
user 0m0.312s
sys 0m0.001s
或者你可以只使用LLVM并让它做起来很神奇(注意这个版本仍使用quotRem
)
$ time ./so
(837799,525)
real 0m0.286s
user 0m0.283s
sys 0m0.002s
LLVM实际上运行良好,只要你避免使用mod
的可怕性,并使用rem
或even
优化基于防护的代码,同样优化手工优化.&.
与shiftR
。
对于比原始结果快约20倍的结果。
编辑:人们惊讶于,在面对Int
时,quotRem和位操作一样好。代码包含在内,但我并不清楚:只是因为某些东西可能是负面的并不意味着你无法使用非常相似的位操作来处理它,而这些位操作可能在正确的硬件上具有相同的成本。 nextStep
的所有三个版本似乎都表现相同(ghc -O2 -fforce-recomp -fllvm
,ghc版本7.6.3,LLVM 3.3,x86-64)。
{-# LANGUAGE BangPatterns, UnboxedTuples #-}
import Data.Bits
collatzLength :: Int -> Int
collatzLength x| x == 1 = 1
| otherwise = go x 0
where
go 1 a = a + 1
go x !a = go (nextStep x) (a+1)
longestChain :: (Int, Int) -> Int -> Int -> (Int,Int)
longestChain (num, numLength) bound !counter
| counter >= bound = (num, numLength)
| otherwise = longestChain (longerOf (num,numLength) (counter, collatzLength counter)) bound (counter + 1)
--I know this is a messy function, but I was doing this problem just
--for myself, so I didn't bother making some utility functions for it.
--also, I split the big line in half to display on here nicer, would
--it actually run with this line split?
longerOf :: (Int,Int) -> (Int,Int) -> (Int,Int)
longerOf (a1,a2) (b1,b2)| a2 > b2 = (a1,a2)
| otherwise = (b1,b2)
{-# INLINE longerOf #-}
nextStep :: Int -> Int
-- Version 'bits'
nextStep n = if 0 == n .&. 1 then n `shiftR` 1 else 3*n+1
-- Version 'quotRem'
-- nextStep n = let (q,r) = quotRem n 2 in if r == 0 then q else 3*n+1
-- Version 'almost the original'
-- nextStep n | even n = quot n 2
-- | otherwise = 3*n + 1
{-# INLINE nextStep #-}
main = print (longestChain (0,0) 1000000 1)