提高在 Haskell 中求图直径的性能

时间:2018-01-03 16:46:42

标签: performance haskell graph functional-programming

我正在解决下面这道 problem,其本质是"在 Haskell 中求连通无向带权图的直径"。下面的解决方案能得出正确答案,但在 27 个测试中有 9 个超出了时间限制。我远不是 Haskell 高手,能否请各位告诉我,是否以及如何提高这个解决方案的性能(例如借助内置的 Data.Graph 模块)?我尝试过在一些地方使用累加器参数、严格的元组和严格求值,但要么是我用错了,要么性能问题出在别处。提前谢谢!

import qualified Data.Map as Map
import qualified Data.Set as Set
import Data.List (maximumBy)
import Data.Ord (comparing)

-- | Build an undirected adjacency map from a flat list of @from to dist@
-- triples, inserting every edge in both directions. A new edge is
-- prepended to the node's existing neighbour list.
--
-- The accumulator is forced at every step. Otherwise a long input leaves a
-- chain of unevaluated 'Map.insertWith' calls that is only collapsed at the
-- very end.
buildGraph :: [Int] -> Map.Map Int [(Int, Int)] -> Map.Map Int [(Int, Int)]
buildGraph [] acc = acc
buildGraph (from : to : dist : rest) acc =
  let withTo     = Map.insertWith (++) from [(to, dist)] acc
      withFromTo = Map.insertWith (++) to [(from, dist)] withTo
   in withFromTo `seq` buildGraph rest withFromTo
buildGraph _ _ =
  error "buildGraph: input length is not a multiple of 3 (from, to, dist)"

-- | A purely functional FIFO queue made of two lists. New elements are
-- pushed onto 'ingoing' and removed from 'outgoing'. When 'outgoing' runs
-- dry, 'ingoing' is reversed into it.
data Queue a = Queue
  { ingoing  :: [a]
  , outgoing :: [a]
  } deriving Show

-- | A queue that yields the elements of @xs@ in list order.
toQueue :: [a] -> Queue a
toQueue xs = Queue [] xs

-- | Enqueue every element of @xs@ at the back, keeping their order (the
-- first element of @xs@ comes out first among them).
enqMany :: [a] -> Queue a -> Queue a
enqMany xs (Queue is os) = Queue (reverse xs ++ is) os

-- | Remove the front element.
--
-- Dequeuing an empty queue is an explicit error. Without the first
-- equation, reversing an empty 'ingoing' list leaves both lists empty and
-- 'deq' would call itself forever.
deq :: Queue a -> (a, Queue a)
deq (Queue [] [])       = error "deq: empty queue"
deq (Queue is [])       = deq (Queue [] (reverse is))
deq (Queue is (o : os)) = (o, Queue is os)

-- | The list stored under key @k@ in @m@. A missing key aborts with an
-- error; callers are expected to look up only keys present in the map.
extract :: (Ord a) => a -> Map.Map a [b] -> [b]
extract k m = Map.findWithDefault missing k m
  where
    missing = error "sdfsd"

-- | Breadth-first traversal from @node@ over the adjacency map @graph@.
--
-- Returns every reachable node paired with the sum of edge weights along
-- the path by which it was first dequeued. The most recently visited node
-- comes first in the list. On a tree that path is unique, so the sum is the
-- true distance.
bfs :: Int -> Map.Map Int [(Int, Int)] -> [(Int, Int)]
bfs node graph = bfs' Set.empty (toQueue [(node, 0)]) []
  where
    bfs' :: Set.Set Int -> Queue (Int, Int) -> [(Int, Int)] -> [(Int, Int)]
    bfs' _ (Queue [] []) acc = acc
    bfs' visited que acc =
      let ((n, dist), rest) = deq que
       in if Set.member n visited
            then bfs' visited rest acc
            else
              let children = map (\(i, d) -> (i, d + dist)) (extract n graph)
                  newNodes = enqMany children rest
               in -- Force the distance now so each entry holds a number
                  -- rather than a chain of (+) thunks back to the start.
                  dist `seq` bfs' (Set.insert n visited) newNodes ((n, dist) : acc)

-- | The (node, distance) pair whose distance is largest. The structure
-- must be non-empty.
findMostDistant :: (Foldable t, Ord d) => t (n, d) -> (n, d)
findMostDistant = maximumBy (comparing snd)

-- | Diameter of the weighted tree described by @input@.
--
-- @input@ holds whitespace-separated integers: a leading count followed by
-- one @from to dist@ triple per edge. In the sample tests that count is the
-- number of nodes (e.g. 11 followed by 10 triples), not the number of edges;
-- it is not needed here.
--
-- Two searches are made: first the node farthest from an arbitrary start,
-- then the largest distance from that node. This two-pass method assumes
-- the graph is a tree. Input with no edges (a single node) has diameter 0
-- instead of crashing on @head []@.
solve :: String -> Int
solve input =
  case map read (words input) of
    (_ : triples@(start : _)) ->
      let graph = buildGraph triples Map.empty
          -- Farthest node from the arbitrarily chosen start (first edge's source).
          (mostDistant, _) = findMostDistant (bfs start graph)
          -- Largest distance from that node: the diameter.
          (_, answer) = findMostDistant (bfs mostDistant graph)
       in answer
    _ -> 0

-- | Sample inputs; the trailing comment on each line is its expected answer.
tests :: [String]
tests =
  [ "11 2 7 2 1 7 6 5 1 8 2 8 6 8 6 9 10 5 5 9 1 9 0 10 15 3 1 21 6 4 3" -- 54
  , "5 3 4 3 0 3 4 0 2 6 1 4 9" -- 22
  , "16 2 3 92 5 2 10 14 3 42 2 4 26 14 12 50 4 6 93 9 6 24 15 14 9 0 2 95 8 0 90 0 13 60 9 10 59 1 0 66 11 12 7 7 10 35" -- 428
  ]

-- | Solve every sample in 'tests' and print each answer on its own line.
runZeroTests :: IO ()
runZeroTests = mapM_ (print . solve) tests

-- | Read all of stdin as one problem instance and print its answer.
main :: IO ()
main = getContents >>= print . solve

3 个答案:

答案 0 :(得分:2)

我认为 `deq (Queue [] [])` 会导致无限循环:两个列表都为空时,它会对空列表取反后再次以同样的参数调用自身。

答案 1 :(得分:1)

当我用 Haskell 解竞赛题时,最大的性能问题通常是缓慢的 I/O 库:它操作的是由宽字符组成的惰性单链表。我在编程竞赛中总是先把它换成快速 I/O。

下面这个版本只对程序逻辑做了很小的改动:仅把 I/O 换成 Data.ByteString.Lazy.Char8(它用惰性求值的严格字节数组列表实现),并用 Data.ByteString.Builder 构建一个填充输出缓冲区的函数。单独测一下快速 I/O 带来的加速幅度应该会很有帮助。

{-# LANGUAGE OverloadedStrings #-} -- Added

import Data.ByteString.Builder
  (Builder, char7, intDec, toLazyByteString) -- Added
import qualified Data.ByteString.Lazy.Char8 as B8 -- Added
import qualified Data.Map as Map
import qualified Data.Set as Set
import Data.List (maximumBy)
import Data.Maybe (fromJust) -- Added
import Data.Monoid ((<>)) -- Added
import Data.Ord (comparing)

-- | Build an undirected adjacency map from a flat list of @from to dist@
-- triples, inserting every edge in both directions. A new edge is
-- prepended to the node's existing neighbour list.
--
-- The accumulator is forced at every step. Otherwise a long input leaves a
-- chain of unevaluated 'Map.insertWith' calls that is only collapsed at the
-- very end.
buildGraph :: [Int] -> Map.Map Int [(Int, Int)] -> Map.Map Int [(Int, Int)]
buildGraph [] acc = acc
buildGraph (from : to : dist : rest) acc =
  let withTo     = Map.insertWith (++) from [(to, dist)] acc
      withFromTo = Map.insertWith (++) to [(from, dist)] withTo
   in withFromTo `seq` buildGraph rest withFromTo
buildGraph _ _ =
  error "buildGraph: input length is not a multiple of 3 (from, to, dist)"

-- | A purely functional FIFO queue made of two lists. New elements are
-- pushed onto 'ingoing' and removed from 'outgoing'. When 'outgoing' runs
-- dry, 'ingoing' is reversed into it.
data Queue a = Queue
  { ingoing  :: [a]
  , outgoing :: [a]
  } deriving Show

-- | A queue that yields the elements of @xs@ in list order.
toQueue :: [a] -> Queue a
toQueue xs = Queue [] xs

-- | Enqueue every element of @xs@ at the back, keeping their order (the
-- first element of @xs@ comes out first among them).
enqMany :: [a] -> Queue a -> Queue a
enqMany xs (Queue is os) = Queue (reverse xs ++ is) os

-- | Remove the front element.
--
-- Dequeuing an empty queue is an explicit error. Without the first
-- equation, reversing an empty 'ingoing' list leaves both lists empty and
-- 'deq' would call itself forever.
deq :: Queue a -> (a, Queue a)
deq (Queue [] [])       = error "deq: empty queue"
deq (Queue is [])       = deq (Queue [] (reverse is))
deq (Queue is (o : os)) = (o, Queue is os)

-- | The list stored under key @k@ in @m@. A missing key aborts with an
-- error; callers are expected to look up only keys present in the map.
extract :: (Ord a) => a -> Map.Map a [b] -> [b]
extract k m = Map.findWithDefault missing k m
  where
    missing = error "sdfsd"

-- | Breadth-first traversal from @node@ over the adjacency map @graph@.
--
-- Returns every reachable node paired with the sum of edge weights along
-- the path by which it was first dequeued. The most recently visited node
-- comes first in the list. On a tree that path is unique, so the sum is the
-- true distance.
bfs :: Int -> Map.Map Int [(Int, Int)] -> [(Int, Int)]
bfs node graph = bfs' Set.empty (toQueue [(node, 0)]) []
  where
    bfs' :: Set.Set Int -> Queue (Int, Int) -> [(Int, Int)] -> [(Int, Int)]
    bfs' _ (Queue [] []) acc = acc
    bfs' visited que acc =
      let ((n, dist), rest) = deq que
       in if Set.member n visited
            then bfs' visited rest acc
            else
              let children = map (\(i, d) -> (i, d + dist)) (extract n graph)
                  newNodes = enqMany children rest
               in -- Force the distance now so each entry holds a number
                  -- rather than a chain of (+) thunks back to the start.
                  dist `seq` bfs' (Set.insert n visited) newNodes ((n, dist) : acc)

-- | The (node, distance) pair whose distance is largest. The structure
-- must be non-empty.
findMostDistant :: (Foldable t, Ord d) => t (n, d) -> (n, d)
findMostDistant = maximumBy (comparing snd)

-- | Diameter of the weighted tree given as a flat list of @from to dist@
-- triples; the leading count must already have been removed by 'parse'.
--
-- Two searches are made: first the node farthest from an arbitrary start,
-- then the largest distance from that node. This assumes the graph is a
-- tree. With no edges (a single-node tree) the diameter is 0 instead of a
-- crash on @head []@.
solve :: [Int] -> Int
solve [] = 0
solve triples@(start : _) = answer
  where
    graph = buildGraph triples Map.empty
    -- Farthest node from the arbitrarily chosen start (first edge's source).
    (mostDistant, _) = findMostDistant (bfs start graph)
    -- Largest distance from that node: the diameter.
    (_, answer) = findMostDistant (bfs mostDistant graph)

-- | Sample inputs, unchanged from the question; the trailing comment on each
-- line is its expected answer. With OverloadedStrings enabled, these
-- literals are read as ByteStrings.
tests =
  [ "11 2 7 2 1 7 6 5 1 8 2 8 6 8 6 9 10 5 5 9 1 9 0 10 15 3 1 21 6 4 3" -- 54
  , "5 3 4 3 0 3 4 0 2 6 1 4 9" -- 22
  , "16 2 3 92 5 2 10 14 3 42 2 4 26 14 12 50 4 6 93 9 6 24 15 14 9 0 2 95 8 0 90 0 13 60 9 10 59 1 0 66 11 12 7 7 10 35" -- 428
  ]

-- | Solve every sample in 'tests' and write the answers, one per line, as
-- a single ByteString built in one pass.
runZeroTests :: IO ()
runZeroTests = B8.putStr (toLazyByteString (foldMap (format . solve . parse) tests))

-- | Read the problem from stdin with ByteString I/O and write the answer.
main :: IO ()
main = B8.interact $ \input -> toLazyByteString (format (solve (parse input)))

-- | Parse whitespace-separated integers and drop the first one.
--
-- The first number of the input is the node count, which 'solve' does not
-- need. Empty input yields @[]@ instead of crashing in 'tail'. A token that
-- does not start with an integer is reported with a descriptive error
-- instead of an anonymous 'fromJust' failure.
parse :: B8.ByteString -> [Int]
parse = map readToken . drop 1 . B8.words
  where
    readToken tok =
      maybe (error ("parse: not an integer: " ++ B8.unpack tok)) fst (B8.readInt tok)

-- | Render one answer as its decimal digits followed by a newline.
format :: Int -> Builder
format n = mconcat [intDec n, char7 '\n']

答案 2 :(得分:1)

在 @Davislor 的帮助下(使用 ByteString 进行 IO),再加上其他一些改动,我在这道题上拿到了 100 分。最终我为优化做了以下几件事:

  1. 按 @Davislor 的建议,使用 ByteString 进行 IO。
  2. 由于我知道输入中的整数都是合法的,我编写了自己的 parseInt 函数,省去了不必要的检查。
  3. 我用惰性 Array 而不是 Map 来构建邻接表。我不知道用 accumArray 构造 Array 的渐近复杂度是多少(我相信应该是 O(n)),但数组中的查找应该是 O(1),而不是 Map 的 O(log n)。
  4. 以下是最终解决方案:

    (代码见下文)

    仍有改进的余地,访问过的节点的{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE BangPatterns #-} import Data.ByteString.Builder (Builder, char7, intDec, toLazyByteString) import qualified Data.ByteString.Lazy.Char8 as B8 import qualified Data.Set as Set import Data.Monoid ((<>)) import Data.Char (ord) import Data.ByteString (getLine) import Data.Array (Array, array, accumArray, (!), (//)) buildAdjList :: Int -> [Int] -> Array Int [(Int, Int)] buildAdjList n xs = accumArray (flip (:)) [] (0, n) $ triples xs [] where triples [] res = res triples (x:y:dist:rest) res = let edgeXY = (x, (y, dist)) edgeYX = (y, (x, dist)) in triples rest (edgeXY:edgeYX:res) data Queue a = Queue { ingoing :: [a] , outgoing :: [a] } deriving Show enqMany xs (Queue is os) = Queue (reverse xs ++ is) os deq (Queue [] []) = error "gosho" deq (Queue is []) = deq (Queue [] $ reverse is) deq (Queue is (o:os)) = (o, Queue is os) bfs !node adjList = let start = (node, 0) in bfs' Set.empty (Queue [] [start]) start where bfs' :: Set.Set Int -> Queue (Int, Int) -> (Int, Int) -> (Int, Int) bfs' visited (Queue [] []) !ans = ans bfs' visited que !ans = let (curr@(n, dist), rest) = deq que in if Set.member n visited then bfs' visited rest ans else let children = map (\(i, d) -> (i, d + dist)) $ adjList ! n newNodes = enqMany children rest in bfs' (Set.insert n visited) newNodes (longerEdge curr ans) longerEdge :: (Int, Int) -> (Int, Int) -> (Int, Int) longerEdge a b = if (snd a) < (snd b) then b else a parseInt :: B8.ByteString -> Int parseInt str = parseInt' str 0 where parseInt' str !acc | B8.null str = acc | otherwise = parseInt' (B8.tail str) $ ((ord $ B8.head str) - 48 + acc * 10) parseIntList :: B8.ByteString -> [Int] parseIntList = map parseInt . 
B8.words solve :: [Int] -> Int solve (n:triples) = answer where graph = buildAdjList n triples -- pick arbitary node, find the farther node from it using bfs (mostDistant, _) = bfs (head triples) graph -- find the farthest node from the previously farthest node, counting the distance on the way (_, answer) = bfs mostDistant graph main :: IO () main = B8.interact ( toLazyByteString . intDec . solve . parseIntList ) -- debug code below tests = [ "11 2 7 2 1 7 6 5 1 8 2 8 6 8 6 9 10 5 5 9 1 9 0 10 15 3 1 21 6 4 3" -- 54 , "5 3 4 3 0 3 4 0 2 6 1 4 9" -- 22 , "16 2 3 92 5 2 10 14 3 42 2 4 26 14 12 50 4 6 93 9 6 24 15 14 9 0 2 95 8 0 90 0 13 60 9 10 59 1 0 66 11 12 7 7 10 35" -- 428 ] runZeroTests = B8.putStr . toLazyByteString . foldMap format . map (solve . parseIntList) $ tests format :: Int -> Builder format n = intDec n <> eol where eol = char7 '\n' 可以更改为位数组,可以使用Set代替Int32Int可以应用,虽然我觉得我无法理解Haskell程序的执行顺序。