为什么这个简单的haskell算法这么慢?

时间:2011-12-28 17:45:53

标签: haskell collatz

剧透警报:这与Project Euler的Problem 14有关。

以下代码需要大约15秒才能运行。我有一个在1s内运行的非递归Java解决方案。我想我应该能够更接近这个代码。

import Data.List

collatz a 1  = a
collatz a x
  | even x    = collatz (a + 1) (x `div` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)

main = do
  print ((foldl1' max) . map (collatz 1) $ [1..1000000])

我已经使用+RHS -p进行了分析,并注意到分配的内存很大,并随着输入的增长而增长。对于n = 100,000分配了1gb(!),为n = 1,000,000分配了13gb(!!)。

然后,-sstderr表明尽管分配了大量字节,但总内存使用量为1mb,生产率为95%+,因此13gb可能是红鲱鱼。

我可以想到几个可能性:

  1. 某些事情并不像它需要的那样严格。我已经发现了 foldl1',但也许我需要做更多?是否可以标记collatz 严格(甚至有意义吗?

  2. collatz不是尾调用优化。我认为应该但不是 知道确认的方法。

  3. 例如,编译器没有做我认为应该做的一些优化 只有两个collatz的结果在任何时候都需要在内存中(最大值和当前值)

  4. 有什么建议吗?

    这几乎与Why is this Haskell expression so slow?重复,但我会注意到快速Java解决方案不必执行任何memoization。有没有什么方法可以加快速度而不必诉诸它?

    作为参考,这是我的分析输出:

      Wed Dec 28 09:33 2011 Time and Allocation Profiling Report  (Final)
    
         scratch +RTS -p -hc -RTS
    
      total time  =        5.12 secs   (256 ticks @ 20 ms)
      total alloc = 13,229,705,716 bytes  (excludes profiling overheads)
    
    COST CENTRE                    MODULE               %time %alloc
    
    collatz                        Main                  99.6   99.4
    
    
                                                                                                   individual    inherited
    COST CENTRE              MODULE                                               no.    entries  %time %alloc   %time %alloc
    
    MAIN                     MAIN                                                   1           0   0.0    0.0   100.0  100.0
     CAF                     Main                                                 208          10   0.0    0.0   100.0  100.0
      collatz                Main                                                 215           1   0.0    0.0     0.0    0.0
      main                   Main                                                 214           1   0.4    0.6   100.0  100.0
       collatz               Main                                                 216           0  99.6   99.4    99.6   99.4
     CAF                     GHC.IO.Handle.FD                                     145           2   0.0    0.0     0.0    0.0
     CAF                     System.Posix.Internals                               144           1   0.0    0.0     0.0    0.0
     CAF                     GHC.Conc                                             128           1   0.0    0.0     0.0    0.0
     CAF                     GHC.IO.Handle.Internals                              119           1   0.0    0.0     0.0    0.0
     CAF                     GHC.IO.Encoding.Iconv                                113           5   0.0    0.0     0.0    0.0
    

    和-sstderr:

    ./scratch +RTS -sstderr 
    525
      21,085,474,908 bytes allocated in the heap
          87,799,504 bytes copied during GC
               9,420 bytes maximum residency (1 sample(s))          
              12,824 bytes maximum slop               
                   1 MB total memory in use (0 MB lost due to fragmentation)  
    
      Generation 0: 40219 collections,     0 parallel,  0.40s,  0.51s elapsed
      Generation 1:     1 collections,     0 parallel,  0.00s,  0.00s elapsed
    
      INIT  time    0.00s  (  0.00s elapsed)
      MUT   time   35.38s  ( 36.37s elapsed)
      GC    time    0.40s  (  0.51s elapsed)
      RP    time    0.00s  (  0.00s elapsed)  PROF  time    0.00s  (  0.00s elapsed)
      EXIT  time    0.00s  (  0.00s elapsed)
      Total time   35.79s  ( 36.88s elapsed)  %GC time       1.1%  (1.4% elapsed)  Alloc rate    595,897,095 bytes per MUT second
    
      Productivity  98.9% of total user, 95.9% of total elapsed
    

    Java解决方案(不是我的,取自Project Euler论坛并删除了memoization):

    public class Collatz {
      public int getChainLength( int n )
      {
        long num = n;
        int count = 1;
        while( num > 1 )
        {
          num = ( num%2 == 0 ) ? num >> 1 : 3*num+1;
          count++;
        }
        return count;
      }
    
      public static void main(String[] args) {
        Collatz obj = new Collatz();
        long tic = System.currentTimeMillis();
        int max = 0, len = 0, index = 0;
        for( int i = 3; i < 1000000; i++ )
        {
          len = obj.getChainLength(i);
          if( len > max )
          {
            max = len;
            index = i;
          }
        }
        long toc = System.currentTimeMillis();
        System.out.println(toc-tic);
        System.out.println( "Index: " + index + ", length = " + max );
      }
    }
    

3 个答案:

答案 0 :(得分:21)

起初,我认为你应该尝试在collatz中的 a 之前加上感叹号:

collatz !a 1  = a
collatz !a x
  | even x    = collatz (a + 1) (x `div` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)

(您需要将{-# LANGUAGE BangPatterns #-}放在源文件的顶部才能生效。)

我的理由如下:问题在于你在第一个争论中建立一个巨大的 thunk :它从1开始,然后变成{{ 1}},然后变成1 + 1,......所有这一切都没有被强迫。这个bang pattern强制(1 + 1) + 1的第一个参数在每次调用时被强制,所以它从1开始,然后变为2,依此类推,而不会构建一个大的未评估的thunk:它只是保持整数。

请注意,爆炸模式只是使用seq的简写;在这种情况下,我们可以按如下方式重写collatz

collatz

这里的技巧是在守卫中强制 a ,然后总是评估为False(因此身体是无关紧要的)。然后继续评估下一个案例, a 已经过评估。然而,爆炸模式更清晰。

不幸的是,当使用collatz a _ | seq a False = undefined collatz a 1 = a collatz a x | even x = collatz (a + 1) (x `div` 2) | otherwise = collatz (a + 1) (3 * x + 1) 进行编译时,这不会比原始版本运行得更快!我们还能尝试什么?好吧,我们可以做的一件事是假设这两个数字永远不会溢出一个机器大小的整数,并给这个类型的注释-O2

collatz

我们将把爆炸模式留在那里,因为即使它们不是性能问题的根源,我们仍然应该避免积累thunk。这使我的(慢)计算机上的时间减少到8.5秒。

下一步是尝试将其更接近Java解决方案。要实现的第一件事是,在Haskell中,collatz :: Int -> Int -> Int 在负数整数方面表现得更加数学上正确,但比“正常”C除法慢,在Haskell中称为div。用quot替换div会将运行时间降低到5.2秒,并将quot替换为x `quot` 2(导入Data.Bits)以匹配Java解决方案,将其降低到4.9秒

现在这个数字差不多我可以得到它,但我认为这是一个非常好的结果;由于您的计算机速度比我的快,因此应该更接近Java解决方案。

这是最终的代码(我在路上做了一些清理工作):

x `shiftR` 1

观看GHC核心这个项目(ghc-core),我认为这可能和它一样好; {-# LANGUAGE BangPatterns #-} import Data.Bits import Data.List collatz :: Int -> Int collatz = collatz' 1 where collatz' :: Int -> Int -> Int collatz' !a 1 = a collatz' !a x | even x = collatz' (a + 1) (x `shiftR` 1) | otherwise = collatz' (a + 1) (3 * x + 1) main :: IO () main = print . foldl1' max . map collatz $ [1..1000000] 循环使用未装箱的整数,程序的其余部分看起来没问题。我能想到的唯一改进就是从collatz迭代中消除拳击。

顺便说一句,不要担心“总分配”数字;它是在程序生命周期内分配的总内存,即使GC回收该内存,它也不会减少。多TB的数字很常见。

答案 1 :(得分:2)

你可能会失去列表和爆炸模式,而是通过使用堆栈来获得相同的性能。

import Data.List
import Data.Bits

coll :: Int -> Int
coll 0 = 0
coll 1 = 1
coll 2 = 2
coll n =
  let a = coll (n - 1)
      collatz a 1 = a
      collatz a x
        | even x    = collatz (a + 1) (x `shiftR` 1)
        | otherwise = collatz (a + 1) (3 * x + 1)
  in max a (collatz 1 n)


main = do
  print $ coll 100000

这样做的一个问题是你必须为大输入增加堆栈的大小,比如1_000_000。

<强>更新

这是一个尾递归版本,不会遇到堆栈溢出问题。

import Data.Word
collatz :: Word -> Word -> (Word, Word)
collatz a x
  | x == 1    = (a,x)
  | even x    = collatz (a + 1) (x `quot` 2)
  | otherwise = collatz (a + 1) (3 * x + 1)

coll :: Word -> Word
coll n = collTail 0 n
  where
    collTail m 1 = m
    collTail m n = collTail (max (fst $ collatz 1 n) m) (n-1)

请注意使用Word代替Int。它在性能上有所不同。如果你愿意,你仍然可以使用爆炸模式,这几乎会使性能翻倍。

答案 2 :(得分:0)

我发现的一件事在这个问题上产生了惊人的差异。我坚持直接复发关系而不是折叠,你应该原谅表达,用它来计算。重写

collatz n = if even n then n `div` 2 else 3 * n + 1

作为

collatz n = case n `divMod` 2 of
            (n', 0) -> n'
            _       -> 3 * n + 1
在具有2.8 GHz Athlon II X4 430 CPU的系统上,我的程序运行时间缩短了1.2秒。我最初的更快版本(使用divMod后2.3秒):

{-# LANGUAGE BangPatterns #-}

import Data.List
import Data.Ord

collatzChainLen :: Int -> Int
collatzChainLen n = collatzChainLen' n 1
    where collatzChainLen' n !l
            | n == 1    = l
            | otherwise = collatzChainLen' (collatz n) (l + 1)

collatz:: Int -> Int
collatz n = case n `divMod` 2 of
                 (n', 0) -> n'
                 _       -> 3 * n + 1

pairMap :: (a -> b) -> [a] -> [(a, b)]
pairMap f xs = [(x, f x) | x <- xs]

main :: IO ()
main = print $ fst (maximumBy (comparing snd) (pairMap collatzChainLen [1..999999]))

或许更惯用的Haskell版本在大约9.7秒内运行(带有divMod的8.5);除了

之外,它是完全相同的
collatzChainLen :: Int -> Int
collatzChainLen n = 1 + (length . takeWhile (/= 1) . (iterate collatz)) n

使用Data.List.Stream应该允许流融合,这将使该版本运行更像具有显式累积,但我找不到具有Data.List.Stream的Ubuntu libghc *包,所以我还不能验证。