有没有办法内联递归函数?

时间:2017-02-11 18:44:00

标签: haskell

这是我的previous question的后续行动,我问为什么流融合并没有在某个程序中踢。事实证明,问题在于某些函数没有内联,并且INLINE标志将性能提高了大约17x(这表明了内联的重要性!)。

现在,请注意,在原始问题上,我同时对64incAll次电话进行了硬编码。现在,假设我创建了一个nTimes函数,它重复调用一个函数:

module Main where

import qualified Data.Vector.Unboxed as V

{-# INLINE incAll #-}
incAll :: V.Vector Int -> V.Vector Int
incAll = V.map (+ 1)

{-# INLINE nTimes #-}
nTimes :: Int -> (a -> a) -> a -> a
nTimes 0 f x = x
nTimes n f x = f (nTimes (n-1) f x)

main :: IO ()
main = do
  let size = 100000000 :: Int
  let array = V.replicate size 0 :: V.Vector Int
  print $ V.sum (nTimes 64 incAll array)

在这种情况下,仅向INLINE添加nTimes pragma将无济于事,因为AFAIK GHC不会内联递归函数。是否有任何技巧迫使GHC在编译时扩展nTimes,从而恢复预期的性能?

3 个答案:

答案 0 :(得分:27)

不,但你可以使用更好的功能。我不是在谈论V.map (+64),这会让事情变得更快,但关于nTimes。我们有三位候选人已经完成了nTimes所做的事情:

{-# INLINE nTimesFoldr #-}
nTimesFoldr :: Int -> (a -> a) -> a -> a    
nTimesFoldr n f x = foldr (.) id (replicate n f) $ x

{-# INLINE nTimesIterate #-}
nTimesIterate :: Int -> (a -> a) -> a -> a    
nTimesIterate n f x = iterate f x !! n

{-# INLINE nTimesTail #-}
nTimesTail :: Int -> (a -> a) -> a -> a    
nTimesTail n f = go n
  where
    {-# INLINE go #-}
    go n x | n <= 0 = x
    go n x          = go (n - 1) (f x)

所有版本大约需要8秒,而版本需要40秒。顺便提一下,约阿希姆的版本还需要8秒。请注意,iterate版本在我的系统上占用更多内存。虽然GHC有unroll plugin,但在过去五年中它没有更新(它使用自定义的ANNotations)。

根本没有展开?

然而,在我们绝望之前,GHC实际上是如何试图内联一切的?让我们使用nTimesTailnTimes 1

module Main where
import qualified Data.Vector.Unboxed as V

{-# INLINE incAll #-}
incAll :: V.Vector Int -> V.Vector Int
incAll = V.map (+ 1)

{-# INLINE nTimes #-}
nTimes :: Int -> (a -> a) -> a -> a    
nTimes n f = go n
  where
    {-# INLINE go #-}
    go n x | n <= 0 = x
    go n x          = go (n - 1) (f x)

main :: IO ()
main = do
  let size = 100000000 :: Int
  let array = V.replicate size 0 :: V.Vector Int
  print $ V.sum (nTimes 1 incAll array)
$ stack ghc --package vector -- -O2 -ddump-simpl -dsuppress-all SO.hs
main2 =
  case (runSTRep main3) `cast` ...
  of _ { Vector ww1_s9vw ww2_s9vx ww3_s9vy ->
  case $wgo 1 ww1_s9vw ww2_s9vx ww3_s9vy
  of _ { (# ww5_s9w3, ww6_s9w4, ww7_s9w5 #) ->

我们可以在那里停下来。 $wgo是上面定义的go。即使使用1 GHC也不会展开循环。这是令人不安的,因为1是一个常数。

救援模板

但是,唉,并非一无所获。如果C ++程序员能够为编译时常量执行以下操作,我们也应该这样做,对吗?

template <int N>
struct Call{
    template <class F, class T>
    static T call(F f, T && t){
        return f(Call<N-1>::call(f,std::forward<T>(t)));
    }
};
template <>
struct Call<0>{
    template <class F, class T>
    static T call(F f, T && t){
        return t;
    }  
};

果然,我们可以TemplateHaskell *

-- Times.sh
{-# LANGUAGE TemplateHaskell #-}
module Times where

import Control.Monad (when)
import Language.Haskell.TH

nTimesTH :: Int -> Q Exp
nTimesTH n = do
  f <- newName "f"
  x <- newName "x"

  when (n <= 0) (reportWarning "nTimesTH: argument non-positive")

  let go k | k <= 0 = VarE x
      go k          = AppE (VarE f) (go (k - 1))
  return $ LamE [VarP f,VarP x] (go n)

nTimesTH做什么?它会创建一个新功能,其中名字f将应用于第二个名称x,总共n次。 n现在需要是一个适合我们的编译时常量,因为只有编译时常量才能进行循环展开:

$(nTimesTH 0) = \f x -> x
$(nTimesTH 1) = \f x -> f x
$(nTimesTH 2) = \f x -> f (f x)
$(nTimesTH 3) = \f x -> f (f (f x))
...

有用吗?它快吗?与nTimes相比有多快?让我们尝试另一个main

-- SO.hs
{-# LANGUAGE TemplateHaskell #-}
module Main where
import Times
import qualified Data.Vector.Unboxed as V

{-# INLINE incAll #-}
incAll :: V.Vector Int -> V.Vector Int
incAll = V.map (+ 1)

{-# INLINE nTimes #-}
nTimes :: Int -> (a -> a) -> a -> a    
nTimes n f = go n
  where
    {-# INLINE go #-}
    go n x | n <= 0 = x
    go n x          = go (n - 1) (f x)

main :: IO ()
main = do
  let size = 100000000 :: Int
  let array = V.replicate size 0 :: V.Vector Int
  let vTH   = V.sum ($(nTimesTH 64) incAll array)
  let vNorm = V.sum (nTimes 64 incAll array)
  print $ vTH == vNorm
stack ghc --package vector -- -O2 SO.hs && SO.exe +RTS -t
True
<<ghc: 52000056768 bytes, 66 GCs, 400034700/800026736 avg/max bytes residency (2 samples), 1527M in use, 0.000 INIT (0.000 elapsed), 8.875 MUT (9.119 elapsed), 0.000 GC (0.094 elapsed) :ghc>>

它产生正确的结果。它有多快?让我们再次使用另一个main

main :: IO ()
main = do
  let size = 100000000 :: Int
  let array = V.replicate size 0 :: V.Vector Int
  print $ V.sum ($(nTimesTH 64) incAll array)
     800,048,112 bytes allocated in the heap                                         
           4,352 bytes copied during GC                                              
          42,664 bytes maximum residency (1 sample(s))                               
          18,776 bytes maximum slop                                                  
             764 MB total memory in use (0 MB lost due to fragmentation)             

                                     Tot time (elapsed)  Avg pause  Max pause        
  Gen  0         1 colls,     0 par    0.000s   0.000s     0.0000s    0.0000s        
  Gen  1         1 colls,     0 par    0.000s   0.049s     0.0488s    0.0488s        

  INIT    time    0.000s  (  0.000s elapsed)                                         
  MUT     time    0.172s  (  0.221s elapsed)                                         
  GC      time    0.000s  (  0.049s elapsed)                                         
  EXIT    time    0.000s  (  0.049s elapsed)                                         
  Total   time    0.188s  (  0.319s elapsed)                                         

  %GC     time       0.0%  (15.3% elapsed)                                           

  Alloc rate    4,654,825,378 bytes per MUT second                                   

  Productivity 100.0% of total user, 58.7% of total elapsed        

那么,将它与8s进行比较。所以对于 TL; DR :如果你有编译时常量,并且想要根据该常量创建和/或修改代码,请考虑模板Haskell。

*请注意,这是我编写的第一个模板Haskell代码。小心使用。不要使用过大的n,否则最终可能会出现混乱的功能。

答案 1 :(得分:15)

Andres之前告诉我一个鲜为人知的技巧,你可以通过使用类型类实际获得GHC到内联递归函数。

这个想法是,而不是通常在一个值上执行结构递归的函数。您可以使用类型类定义函数,并对类型参数执行结构递归。在此示例中,键入级别自然数。

GHC将很乐意内联每个递归调用并生成有效的代码,因为每个递归调用都是不同的类型。

我没有对此进行基准测试或查看核心,但速度明显更快。

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Main where

import qualified Data.Vector.Unboxed as V

data Proxy a = Proxy

{-# INLINE incAll #-}
incAll :: V.Vector Int -> V.Vector Int
incAll = V.map (+ 1)

oldNTimes :: Int -> (a -> a) -> a -> a
oldNTimes 0 f x = x
oldNTimes n f x = f (oldNTimes (n-1) f x)

-- New definition

data N = Z | S N

class Unroll (n :: N) where
    nTimes :: Proxy n -> (a -> a) -> a -> a

instance Unroll Z where
    nTimes _ f x = x

instance Unroll n => Unroll (S n) where
    nTimes p f x =
        let Proxy :: Proxy (S n) = p
        in f (nTimes (Proxy :: Proxy n) f x)

main :: IO ()
main = do
  let size = 100000000 :: Int
  let array = V.replicate size 0 :: V.Vector Int
  print $ V.sum (nTimes (Proxy :: Proxy (S (S (S (S (S (S (S (S (S (S (S Z)))))))))))) incAll array)
  print $ V.sum (oldNTimes 11 incAll array)

答案 2 :(得分:4)

没有

你可以写

{-# INLINE nTimes #-}
nTimes :: Int -> (a -> a) -> a -> a
nTimes n f x = go n
  where go 0 = x
        go n = f (go (n-1))

并且GHC会内联nTimes,并且可能会将递归go专门用于您的特定参数incAllarray,但它不会展开循环。