Haskell性能拓扑排序不够快

时间:2018-03-07 23:05:02

标签: performance haskell

我是Haskell的初学者并选择它来解决我班级的编程任务,但是我的解决方案太慢而且没有被接受。我试图对其进行分析,并希望我能从这里得到更高级的Haskellers的一些指示。

到目前为止我班上唯一接受的其他解决方案是用Rust编写的。我确信我应该能够在Haskell中实现类似的性能,并且为了提高性能而编写了可怕的命令性代码,但无济于事。

我的第一个怀疑与work有关,我使用forever来查看度数数组,直到我遇到越界异常。我希望这是尾递归并编译为while (true)样式循环。

我的第二个怀疑是I / O可能会减慢速度。

编辑:问题可能与我的算法有关,因为我没有保留具有indegree 0的节点队列。谢谢@luqui。

EDIT2:似乎真正的瓶颈是I / O,由于@Davislor,我修复了这个问题。

任务基于:http://www.spoj.com/UKCPLAD/problems/TOPOSORT/,我只能使用Haskell平台中的库。

{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE LambdaCase #-}
{-# OPTIONS_GHC -O3 #-}

import Control.Monad
import Data.Array.IO
import Data.IORef
import Data.Int
import Control.Exception

type List = []
type Node = Int32
type Edge = (Node, Node)
type Indegree = Int32

main = do
  (numNodes, _) <- readPair <$> getLine
  edges <- map readPair . lines <$> getContents
  topo numNodes edges

-- lower bound
{-# INLINE lb #-}
lb = 1

topo :: Node -> List Edge -> IO ()
topo numNodes edges = do
    result <- newIORef []
    count <- newIORef 0
    indegrees <- newArray (lb,numNodes) 0 :: IO (IOUArray Node Indegree)
    neighbours <- newArray (lb,numNodes) [] :: IO (IOArray Node (List Node))
    forM_ edges $ \(from,to) -> do
      update indegrees to (+1)
      update neighbours from (to:)
    let work = forever $ do
          z <- getNext indegrees
          modifyIORef' result (z:)
          modifyIORef' count (+1)
          ns <- readArray neighbours z
          forM_ ns $ \n -> update indegrees n pred
    work `catch`
      \(_ :: SomeException) -> do
        count <- readIORef count
        if numNodes == count
          then (mapM_ (\n -> putStr (show n ++ " ")) . reverse) =<< readIORef result
          else putStrLn "Sandro fails."


{-# INLINE update #-}
update a i f = do
  x <- readArray a i
  writeArray a i (f x)

{-# INLINE getNext #-}
getNext indegrees = getNext' indegrees =<< getBounds indegrees

{-# INLINE getNext' #-}
getNext' indegrees (lb,ub) = readArray indegrees lb >>= \case
    0 -> writeArray indegrees lb (-1) >> return lb
    _ -> getNext' indegrees (lb+1,ub)

readPair :: String -> (Node,Node)
{-# INLINE readPair #-}
readPair = toPair . map read . words
  where toPair [x,y] = (x,y)
        toPair _ = error "Only two entries per line allowed"

示例输出

$ ./topo
8 9
1 4
1 2
4 2
4 3
3 2
5 2
3 5
8 2
8 6
^D
1 4 3 5 7 8 2 6

2 个答案:

答案 0 :(得分:3)

如果您还没有,profile your program通过编译-prof -fprof-auto然后使用命令行选项+RTS -p执行。这将生成一个配置文件*.prof,它将告诉您程序一直在花费哪些功能。但是,我可以立即看到最大的浪费时间。你的直觉是正确的:它是I / O.

做了很多,我可以向你保证,你会发现它花了大部分时间做I / O.您应该始终做的第一件事就是加快程序速度,重写它以使用快速I / O.当您使用正确的数据结构时,Haskell是一种快速语言 Prelude中的默认I / O库使用具有延迟评估的thunks的单链接列表,其中每个节点都包含一个Unicode字符。在C中也会很慢!

当输入为ASCII时,Data.ByteString.Lazy.Char8获得最佳结果,Data.ByteString.Builder生成输出。 (另一种选择是Data.Text。)这会让你在输入上得到一个延迟评估的严格字符缓冲区列表(因此交互式输入和输出仍然有效),并在输出上填充单个缓冲区。

使用快速I / O编写程序框架后,下一步是查看算法,尤其是数据结构。使用分析来查看所有时间的位置。但是我建议你使用一个函数式算法,而不是试图用do在Haskell中编写命令式程序。

我几乎总是在Haskell中使用更具功能性的方式处理这样的问题:特别是,我的main函数几乎总是类似于:

import qualified Data.ByteString.Lazy.Char8 as B8

main :: IO()
main = B8.interact ( output . compute . input )

除了调用interact之外的所有内容都是一个纯函数,并隔离了解析代码和格式代码,因此中间的compute部分可以独立于此。

由于这是一项任务,你想自己解决问题,我将不再为你重构程序,但这是我在另一个论坛上回答问题时写的一个例子来执行计数排序。它应该适合作为其他类型问题的骨架。

import Data.Array.IArray (accumArray, assocs)
import Data.Array.Unboxed (UArray)
import Data.ByteString.Builder (Builder, char7, intDec, toLazyByteString)
import qualified Data.ByteString.Lazy.Char8 as B8
import Data.Monoid ((<>))

main :: IO()
main = B8.interact ( output . compute . input ) where
  input :: B8.ByteString -> [Int]
  input = map perLine . tail . B8.lines where
    perLine = decode . B8.readInt

    decode (Just (x, _)) = x
    decode Nothing = error "Invalid input: expected integer."

  compute :: [Int] -> [Int]
  compute = concatMap expand . assocs . countingSort . map encode where
    encode i = (i, 1)

    countingSort :: [(Int, Int)] -> UArray Int Int
    countingSort = accumArray (+) 0 (lower, upper)

    lower = 0
    upper = 1000000

    expand (i,c) = replicate c i

  output :: [Int] -> B8.ByteString
  output = toLazyByteString . foldMap perCase where
    perCase :: Int -> Builder
    perCase x = intDec x <> char7 '\n'

目前,这个版本的运行时间不到其他任何人Haskell solution for the same problem的一半,对于我用过它的实际比赛问题也是如此,the approach generalizes

因此,我建议将I / O更改为类似于第一个,然后进行性能分析,然后返回分析输出,如果这不足以产生差异。这也可能是一个很好的Code Review问题。

答案 1 :(得分:2)

感谢@ Davislor的建议,我设法让它更快更多,我还重构了代码,现在我实际上有 m log( n )算法。令人惊讶的是,这并没有产生太大的影响 - I / O远远超过了算法的次优复杂性。

编辑:摆脱unsafePerformIO,它实际上运行得更快一点。再加上-XStrict剃掉了更多的时间。

{-# LANGUAGE Strict #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -O2 #-}

import Control.Monad
import Data.Array.IO
import Data.Int
import Data.Set (Set)
import qualified Data.Set as Set
import Data.ByteString.Builder (Builder, char7, intDec, toLazyByteString)
import qualified Data.ByteString.Lazy.Char8 as B8
import Data.Monoid ((<>))


type List = []
type Node = Int
type Edge = (Node, Node)
type Indegree = Int

main = B8.putStrLn =<< topo . map readPair . B8.lines =<< B8.getContents

readPair :: B8.ByteString -> (Node,Node)
readPair str = (x,y)
  where
    (Just (x, str')) = B8.readInt str
    (Just (y, _   )) = B8.readInt (B8.tail str')

topo :: List Edge -> IO B8.ByteString
topo inp = do
    let (numNodes, _) = head inp
        edges         = tail inp
    indegrees <- newArray (1,numNodes) 0 :: IO (IOUArray Node Indegree)
    neighbours <- newArray (1,numNodes) [] :: IO (IOArray Node (List Node))

    -- setup
    forM_ edges $ \(from,to) -> do
      update indegrees to (+1)
      update neighbours from (to:)

    zeroes <- collectIndegreeZero [] indegrees =<< getBounds indegrees
    processQueue (Set.fromList zeroes) [] numNodes indegrees neighbours

  where
    collectIndegreeZero acc indegrees (lb,ub)
      | lb > ub = return acc
      | otherwise = do
          indegr <- readArray indegrees lb
          let acc' = if indegr == 0 then (lb:acc) else acc
          collectIndegreeZero acc' indegrees (lb+1,ub)

    processQueue queue result numNodes indegrees neighbours = do
        if null queue
          then if numNodes == 0
              then return . toLazyByteString . foldMap whitespace . reverse $ result
              else return "Sandro fails."
          else do
            (node,queue) <- return $ Set.deleteFindMin queue
            ns <- readArray neighbours node
            queue <- foldM decrIndegrees queue ns
            processQueue queue (node:result) (numNodes-1) indegrees neighbours
      where
        decrIndegrees :: Set Node -> Node -> IO (Set Node)
        decrIndegrees q n = do
            i <- readArray indegrees n
            writeArray indegrees n (i-1)
            return $ if i == 1 then Set.insert n q else q

        whitespace x = intDec x <> char7 ' '

{-# INLINE update #-}
update a i f = do
  x <- readArray a i
  writeArray a i (f x)