在Haskell中执行常量空间嵌套循环的正确方法是什么?

时间:2015-09-02 05:15:01

标签: performance loops haskell

有两个明显的,"惯用的"在Haskell中执行嵌套循环的方法:使用list monad或使用seed替换传统的forM_。我已经设置了一个基准来确定它们是否被编译为紧密循环:

fors

此测试创建一个100x100向量,使用嵌套循环向每个索引写入1并重复100k次。仅使用import Control.Monad.Loop import Control.Monad.Primitive import Control.Monad import Control.Monad.IO.Class import qualified Data.Vector.Unboxed.Mutable as MV import qualified Data.Vector.Unboxed as V times = 100000 side = 100 -- Using `forM_` to replace traditional fors test_a mvec = forM_ [0..times-1] $ \ n -> do forM_ [0..side-1] $ \ y -> do forM_ [0..side-1] $ \ x -> do MV.write mvec (y*side+x) 1 -- Using the list monad to replace traditional forms test_b mvec = sequence_ $ do n <- [0..times-1] y <- [0..side-1] x <- [0..side-1] return $ MV.write mvec (y*side+x) 1 main = do let vec = V.generate (side*side) (const 0) mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int) -- test_a mvec -- test_b mvec vec' <- V.unsafeFreeze mvec :: IO (V.Vector Int) print $ V.sum vec' (ghc版本7.8.4)编译的结果为ghc -O2 test.hs -o test版本的 3.853s forM_ 10.460s。为了提供参考,我还用JavaScript编写了这个测试:

list monad

这个等效的JavaScript程序需要 var side = 100; var times = 100000; var vec = []; for (var i=0; i<side*side; ++i) vec.push(0); for (var n=0; n<times; ++n) for (var y=0; y<side; ++y) for (var x=0; x<side; ++x) vec[x+y*side] = 1; var s = 0; for (var i=0; i<side*side; ++i) s += vec[i]; console.log(s); 来完成,击败Haskell的未装箱的向量,这是不寻常的,这表明Haskell没有在恒定空间中运行循环,而是在做而是分配。然后,我发现了一个声称提供类型保证紧密循环的库1s

Control.Monad.Loop

哪个在 -- Using `for` from Control.Monad.Loop test_c mvec = exec_ $ do n <- for 0 (< times) (+ 1) x <- for 0 (< side) (+ 1) y <- for 0 (< side) (+ 1) liftIO (MV.write mvec (y*side+x) 1) 中运行。尽管如此,该库并没有被广泛使用,因此,获得快速恒定空间二维计算的惯用方法是什么?(注意这不是REPA的情况)因为我想在网格上执行任意IO操作。)

2 个答案:

答案 0 :(得分:16)

用GHC编写严格的变异代码有时会很棘手。我会写一些不同的东西,可能是一种比我更喜欢的漫无边际的方式。

对于初学者,我们应该在任何情况下使用GHC 7.10,因为otherwise forM_和列表monad解决方案永远不会融合。

此外,我将MV.write替换为MV.unsafeWrite,部分原因是因为它更快,但更重要的是它减少了生成的Core中的一些混乱。从现在开始,运行时统计信息引用带unsafeWrite的代码。

可怕的浮动

即使使用GHC 7.10,我们也应首先注意所有[0..times-1][0..side-1]表达式,因为如果我们不采取必要步骤,它们每次都会破坏性能。问题是它们是不变的范围,而-ffull-laziness(默认情况下在-O上启用)会将它们浮动到顶层。这可以防止列表融合,并且在Int#范围内迭代比在盒装Int - s的列表上迭代更便宜,因此这是一个非常糟糕的优化。

让我们在几秒钟内看到一些未更改的运行时(除了使用unsafeWrite)代码。使用ghc -O2 -fllvm,我使用+RTS -s进行计时。

test_a: 1.6
test_b: 6.2
test_c: 0.6

对于GHC Core查看,我使用了ghc -O2 -ddump-simpl -dsuppress-all -dno-suppress-type-signatures

test_a的情况下,[0..99]范围被取消:

main4 :: [Int]
main4 = eftInt 0 99 -- means "enumFromTo" for Int.

虽然最外面的[0..9999]循环被融合成一个尾递归帮助器:

letrec {
          a3_s7xL :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
          a3_s7xL =
            \ (x_X5zl :: Int#) (s1_X4QY :: State# RealWorld) ->
              case a2_s7xF 0 s1_X4QY of _ { (# ipv2_a4NA, ipv3_a4NB #) ->
              case x_X5zl of wild_X1S {
                __DEFAULT -> a3_s7xL (+# wild_X1S 1) ipv2_a4NA;
                99999 -> (# ipv2_a4NA, () #)
              }
              }; }

如果是test_b,则只会取消[0..99] test_b。但是,[IO ()]要慢得多,因为它必须构建和排序实际的[IO ()]列表。至少GHC足够明智,只能为两个内部循环构建一个10000,然后执行 let { lvl7_s4M5 :: [IO ()] lvl7_s4M5 = -- omitted letrec { a2_s7Av :: Int# -> State# RealWorld -> (# State# RealWorld, () #) a2_s7Av = \ (x_a5xi :: Int#) (eta_B1 :: State# RealWorld) -> letrec { a3_s7Au :: [IO ()] -> State# RealWorld -> (# State# RealWorld, () #) a3_s7Au = \ (ds_a4Nu :: [IO ()]) (eta1_X1c :: State# RealWorld) -> case ds_a4Nu of _ { [] -> case x_a5xi of wild1_X1y { __DEFAULT -> a2_s7Av (+# wild1_X1y 1) eta1_X1c; 99999 -> (# eta1_X1c, () #) }; : y_a4Nz ys_a4NA -> case (y_a4Nz `cast` ...) eta1_X1c of _ { (# ipv2_a4Nf, ipv3_a4Ng #) -> a3_s7Au ys_a4NA ipv2_a4Nf } }; } in a3_s7Au lvl7_s4M5 eta_B1; } in -- omitted 次排序。

{-# OPTIONS_GHC -fno-full-laziness #-}

我们如何解决这个问题?我们可以用test_a: 0.5 test_b: 0.48 test_c: 0.5 来解决这个问题。在我们的案例中,这确实有很大帮助:

INLINE

或者,我们可以摆弄-fno-full-laziness pragma。在浮动完成后显然内联函数可以保持良好的性能。我发现即使没有编译指示,GHC也会内联我们的测试函数,但是显式编译指示会导致它仅在浮动后才能内联。例如,如果没有test_a mvec = forM_ [0..times-1] $ \ n -> forM_ [0..side-1] $ \ y -> forM_ [0..side-1] $ \ x -> MV.unsafeWrite mvec (y*side+x) 1 {-# INLINE test_a #-}

,这会带来良好的效果
test_a mvec = 
    forM_ [0..times-1] $ \ n -> 
        forM_ [0..side-1] $ \ y -> 
            forM_ [0..side-1] $ \ x -> 
                MV.unsafeWrite mvec (y*side+x) 1
{-# INLINE [~2] test_a #-} -- "inline before the first phase please"

但过早勾勒导致表现不佳:

INLINE

这个INLINE [~2]解决方案的问题在于,面对GHC的浮动攻击,它相当脆弱。例如,手动内联不会保留性能。以下代码很慢,因为与main = do let vec = V.generate (side*side) (const 0) mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int) forM_ [0..times-1] $ \ n -> forM_ [0..side-1] $ \ y -> forM_ [0..side-1] $ \ x -> MV.unsafeWrite mvec (y*side+x) 1 类似,它使GHC有机会浮出水面:

-fno-full-laziness

那我们该怎么办?

首先,我认为对于那些想要编写高性能代码并且知道自己在做什么的人来说,使用Control.Monad.Loop是一个完全可行的,甚至是更好的选择。例如,它在unordered-containers中使用。有了它,我们可以更精确地控制共享,我们可以随时手动浮出或内联。

对于更常规的代码,我相信使用for或任何其他提供该功能的软件包没有任何问题。许多Haskell用户并不一丝不苟地依赖于小型&#34;边缘&#34;库。我们也可以在期望的一般性中重新实现for :: Monad m => a -> (a -> Bool) -> (a -> a) -> (a -> m ()) -> m () for init while step body = go init where go !i | while i = body i >> go (step i) go i = return () {-# INLINE for #-} 。例如,以下内容的表现与其他解决方案一样:

+RTS -s

在非常恒定的空间中循环

我起初对堆分配的test_a数据感到非常困惑。 -fno-full-laziness test_c分配了times非{}},而{/ 1}} 没有完全懒惰,这些分配与test_b次迭代次数呈线性关系,但-- with -fno-full-laziness, no INLINE pragmas test_a: 242,521,008 bytes test_b: 121,008 bytes test_c: 121,008 bytes -- but 240,120,984 with full laziness! 仅为向量分配了完全懒惰:

INLINE

此外,test_c +RTS -s pragma在这种情况下根本没有帮助。

我花了一些时间试图在Core中为相关程序找到堆分配的迹象,但没有成功,直到实现让我感到震惊:GHC堆栈帧在堆上,包括主线程的帧,以及函数正在进行堆分配的实际上是在最多三个堆栈帧中运行三次嵌套循环。由{-# OPTIONS_GHC -fno-full-laziness #-} -- ... test_a mvec = forM_ [0..times-1] $ \ n -> forM_ [0..side-1] $ \ y -> forM_ [0..side-1] $ \ x -> MV.unsafeWrite mvec (y*side+x) 1 main = do let vec = V.generate (side*side) (const 0) mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int) test_a mvec 注册的堆分配只是堆栈帧的持续弹出和推送。

以下代码从Core中可以看出这一点:

main1 :: State# RealWorld -> (# State# RealWorld, () #)
main1 =
  \ (s_a5HK :: State# RealWorld) ->
    case divInt# 9223372036854775807 8 of ww4_a5vr { __DEFAULT ->

    -- start of vector creation ----------------------
    case tagToEnum# (># 10000 ww4_a5vr) of _ {
      False ->
        case newByteArray# 80000 (s_a5HK `cast` ...)
        of _ { (# ipv_a5fv, ipv1_a5fw #) ->
        letrec {
          $s$wa_s8jS
            :: Int#
               -> Int#
               -> State# (PrimState IO)
               -> (# State# (PrimState IO), Int #)
          $s$wa_s8jS =
            \ (sc_s8jO :: Int#)
              (sc1_s8jP :: Int#)
              (sc2_s8jR :: State# (PrimState IO)) ->
              case tagToEnum# (<# sc1_s8jP 10000) of _ {
                False -> (# sc2_s8jR, I# sc_s8jO #);
                True ->
                  case writeIntArray# ipv1_a5fw sc_s8jO 0 (sc2_s8jR `cast` ...)
                  of s'#_a5Gn { __DEFAULT ->
                  $s$wa_s8jS (+# sc_s8jO 1) (+# sc1_s8jP 1) (s'#_a5Gn `cast` ...)
                  }
              }; } in
        case $s$wa_s8jS 0 0 (ipv_a5fv `cast` ...)
        -- end of vector creation -------------------

        of _ { (# ipv6_a4Hv, ipv7_a4Hw #) ->
        letrec {
          a2_s7MJ :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
          a2_s7MJ =
            \ (x_a5Ho :: Int#) (eta_B1 :: State# RealWorld) ->
              letrec {
                a3_s7ME :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
                a3_s7ME =
                  \ (x1_X5Id :: Int#) (eta1_XR :: State# RealWorld) ->
                    case ipv7_a4Hw of _ { I# dt4_a5x6 ->
                    case writeIntArray#
                           (ipv1_a5fw `cast` ...) (*# x1_X5Id 100) 1 (eta1_XR `cast` ...)
                    of s'#_a5Gn { __DEFAULT ->
                    letrec {
                      a4_s7Mz :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
                      a4_s7Mz =
                        \ (x2_X5J8 :: Int#) (eta2_X1U :: State# RealWorld) ->
                          case writeIntArray#
                                 (ipv1_a5fw `cast` ...)
                                 (+# (*# x1_X5Id 100) x2_X5J8)
                                 1
                                 (eta2_X1U `cast` ...)
                          of s'#1_X5Hf { __DEFAULT ->
                          case x2_X5J8 of wild_X2o {
                            __DEFAULT -> a4_s7Mz (+# wild_X2o 1) (s'#1_X5Hf `cast` ...);
                            99 -> (# s'#1_X5Hf `cast` ..., () #)
                          }
                          }; } in
                    case a4_s7Mz 1 (s'#_a5Gn `cast` ...)
                    of _ { (# ipv2_a4QH, ipv3_a4QI #) ->
                    case x1_X5Id of wild_X1e {
                      __DEFAULT -> a3_s7ME (+# wild_X1e 1) ipv2_a4QH;
                      99 -> (# ipv2_a4QH, () #)
                    }
                    }
                    }
                    }; } in
              case a3_s7ME 0 eta_B1 of _ { (# ipv2_a4QH, ipv3_a4QI #) ->
              case x_a5Ho of wild_X1a {
                __DEFAULT -> a2_s7MJ (+# wild_X1a 1) ipv2_a4QH;
                99999 -> (# ipv2_a4QH, () #)
              }
              }; } in
        a2_s7MJ 0 (ipv6_a4Hv `cast` ...)
        }
        };
      True ->
        case error
               (unpackAppendCString#
                  "Primitive.basicUnsafeNew: length to large: "#
                  (case $wshowSignedInt 0 10000 ([])
                   of _ { (# ww5_a5wm, ww6_a5wn #) ->
                   : ww5_a5wm ww6_a5wn
                   }))
        of wild_00 {
        }
    }
    }

main :: IO ()
main = main1 `cast` ...

main2 :: State# RealWorld -> (# State# RealWorld, () #)
main2 = runMainIO1 (main1 `cast` ...)

main :: IO ()
main = main2 `cast` ...

我的荣耀包括在这里。随意跳过。

test_a

我们也可以通过以下方式很好地演示帧的分配。让我们改变test_a mvec = forM_ [0..times-1] $ \ n -> forM_ [0..side-1] $ \ y -> forM_ [0..side-50] $ \ x -> -- change here MV.unsafeWrite mvec (y*side+x) 1

test_a mvec = 
    forM_ [0..times-1] $ \ n -> 
        forM_ [0..side-50] $ \ y -> -- change here
            forM_ [0..side-1] $ \ x -> 
                MV.unsafeWrite mvec (y*side+x) 1

现在堆分配保持完全相同,因为最里面的循环是尾递归并使用单个帧。通过以下更改,堆分配减半(到124,921,008字节),因为我们推送和弹出一半的帧:

test_b

test_cmain(没有完全懒惰)而是编译为在单个堆栈帧内使用嵌套case构造的代码,并遍历索引以查看哪个应该递增。请参阅核心以了解以下{-# LANGUAGE BangPatterns #-} -- later I'll talk about this {-# OPTIONS_GHC -fno-full-laziness #-} main = do let vec = V.generate (side*side) (const 0) !mvec <- V.unsafeThaw vec :: IO (MV.MVector (PrimState IO) Int) test_c mvec

main1 :: State# RealWorld -> (# State# RealWorld, () #)
main1 =
  \ (s_a5Iw :: State# RealWorld) ->
    case divInt# 9223372036854775807 8 of ww4_a5vT { __DEFAULT ->

    -- start of vector creation ----------------------
    case tagToEnum# (># 10000 ww4_a5vT) of _ {
      False ->
        case newByteArray# 80000 (s_a5Iw `cast` ...)
        of _ { (# ipv_a5g3, ipv1_a5g4 #) ->
        letrec {
          $s$wa_s8ji
            :: Int#
               -> Int#
               -> State# (PrimState IO)
               -> (# State# (PrimState IO), Int #)
          $s$wa_s8ji =
            \ (sc_s8je :: Int#)
              (sc1_s8jf :: Int#)
              (sc2_s8jh :: State# (PrimState IO)) ->
              case tagToEnum# (<# sc1_s8jf 10000) of _ {
                False -> (# sc2_s8jh, I# sc_s8je #);
                True ->
                  case writeIntArray# ipv1_a5g4 sc_s8je 0 (sc2_s8jh `cast` ...)
                  of s'#_a5GP { __DEFAULT ->
                  $s$wa_s8ji (+# sc_s8je 1) (+# sc1_s8jf 1) (s'#_a5GP `cast` ...)
                  }
              }; } in
        case $s$wa_s8ji 0 0 (ipv_a5g3 `cast` ...)
        of _ { (# ipv6_a4MX, ipv7_a4MY #) ->
        case ipv7_a4MY of _ { I# dt4_a5xy ->
        -- end of vector creation

        letrec {
          a2_s7Q6 :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
          a2_s7Q6 =
            \ (x_a5HT :: Int#) (eta_B1 :: State# RealWorld) ->
              letrec {
                a3_s7Q5 :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
                a3_s7Q5 =
                  \ (x1_X5J9 :: Int#) (eta1_XP :: State# RealWorld) ->
                    letrec {
                      a4_s7MZ :: Int# -> State# RealWorld -> (# State# RealWorld, () #)
                      a4_s7MZ =
                        \ (x2_X5Jl :: Int#) (s1_X4Xb :: State# RealWorld) ->
                          case writeIntArray#
                                 (ipv1_a5g4 `cast` ...)
                                 (+# (*# x1_X5J9 100) x2_X5Jl)
                                 1
                                 (s1_X4Xb `cast` ...)
                          of s'#_a5GP { __DEFAULT ->

                          -- the interesting part! ------------------
                          case x2_X5Jl of wild_X1y {
                            __DEFAULT -> a4_s7MZ (+# wild_X1y 1) (s'#_a5GP `cast` ...);
                            99 ->
                              case x1_X5J9 of wild1_X1o {
                                __DEFAULT -> a3_s7Q5 (+# wild1_X1o 1) (s'#_a5GP `cast` ...);
                                99 ->
                                  case x_a5HT of wild2_X1c {
                                    __DEFAULT -> a2_s7Q6 (+# wild2_X1c 1) (s'#_a5GP `cast` ...);
                                    99999 -> (# s'#_a5GP `cast` ..., () #)
                                  }
                              }
                          }
                          }; } in
                    a4_s7MZ 0 eta1_XP; } in
              a3_s7Q5 0 eta_B1; } in
        a2_s7Q6 0 (ipv6_a4MX `cast` ...)
        }
        }
        };
      True ->
        case error
               (unpackAppendCString#
                  "Primitive.basicUnsafeNew: length to large: "#
                  (case $wshowSignedInt 0 10000 ([])
                   of _ { (# ww5_a5wO, ww6_a5wP #) ->
                   : ww5_a5wO ww6_a5wP
                   }))
        of wild_00 {
        }
    }
    }

main :: IO ()
main = main1 `cast` ...

main2 :: State# RealWorld -> (# State# RealWorld, () #)
main2 = runMainIO1 (main1 `cast` ...)

main :: IO ()
main = main2 `cast` ...

瞧:

Control.Monad.Loop

我不得不承认,我基本上不知道为什么有些代码可以避免堆栈框架的创建而有些代码没有。我怀疑内部&#34;内部&#34; out帮助,并且快速检查告诉我Monad.Loop使用CPS编码,这可能与此相关,尽管test_c解决方案对于让浮动很敏感,而我无法确定来自Core的注意事项为什么带有浮动的test_b无法在单个堆栈帧中运行。

现在,在单个堆栈帧中运行的性能优势很小。我们发现test_a只比-O略快。我把这个绕道包括在答案中因为我发现它有启发性。

状态破解和严格绑定

所谓的state hack使GHC积极参与IO和ST行动。我想我应该在这里提一下,因为除了让浮动这是另一件可以彻底破坏性能的事情。

使用优化import Control.Monad import Debug.Trace expensive :: String -> String expensive x = trace "$$$" x main :: IO () main = do str <- fmap expensive getLine replicateM_ 3 $ print str 启用状态黑客,并且可能会渐进地减慢程序的速度。来自Reid Barton的简单示例:

"$$$"

使用GHC-7.10.2,这会在没有优化的情况下打印-O2一次,但在-fno-state-hack打印三次。而且似乎使用GHC-7.10,我们无法用main :: IO () main = do !str <- fmap expensive getLine replicateM_ 3 $ print str (这是Reid Barton链接票证的主题)摆脱这种行为。

严格的monadic绑定可靠地摆脱了这个问题:

-fno-full-laziness

我认为在IO和ST中进行严格绑定是个好习惯。我有一些经验(虽然不是确定的;我远不是GHC专家),如果我们使用test_b,则特别需要严格的绑定。显然,完全懒惰可以帮助摆脱由国家黑客引起的内联引入的一些工作重复;使用!mvec <- V.unsafeThaw vec并且没有完全懒惰,省略{{1}}上的严格绑定会导致轻微的减速和非常难看的Core输出。

答案 1 :(得分:3)

根据我的经验forM_ [0..n-1]可以表现良好,但不幸的是,它不可靠。只需将INLINE pragma添加到test_a并使用-O2即可让它运行得更快(对我而言为4s到1s),但手动内联它(复制粘贴)会再次降低速度。

更可靠的功能是来自forstatistics,其实现为

-- | Simple for loop.  Counts from /start/ to /end/-1.
for :: Monad m => Int -> Int -> (Int -> m ()) -> m ()
for n0 !n f = loop n0
  where
    loop i | i == n    = return ()
           | otherwise = f i >> loop (i+1)
{-# INLINE for #-}

使用它看起来类似于带有列表的forM_

test_d :: MV.IOVector Int -> IO ()
test_d mv =
  for 0 times $ \_ ->
    for 0 side $ \i ->
      for 0 side $ \j ->
        MV.unsafeWrite mv (i*side + j) 1

但执行得非常好(对我来说是0.85秒),没有任何分配列表的风险。