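I'm working on the Car Game problem on Kattis: https://open.kattis.com/problems/cargame There is a five-second time limit, but on the last instance my code takes longer than that to run. I'm fairly sure my approach is right from a big-O perspective, so now I need to optimize it somehow. I downloaded the test data from: http://challenge.csc.kth.se/2013/challenge-2013.tar.bz2
From profiling, it seems most of the run time is spent in containsSub, which is nothing more than an array access and a tail-recursive call. Moreover, it is only called about 100M times, so why does it take 6.5 seconds to run? (6.5 seconds is the time on my laptop. I've found Kattis is usually about twice as slow, so it's probably more like 13 seconds there.) On the statistics page, some C++ solutions run in under a second. Even some Python solutions come in just under 5 seconds.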
    module Main where

    import Control.Monad
    import Data.Array (Array, (!), (//))
    import qualified Data.Array as Array
    import Data.ByteString.Char8 (ByteString)
    import qualified Data.ByteString.Char8 as BS
    import Data.Char
    import Data.List
    import Data.Maybe

    main :: IO ()
    main = do
      [n, m] <- readIntsLn
      dictWords <- replicateM n BS.getLine
      let suffixChains = map (\w -> (w, buildChain w)) dictWords
      replicateM_ m $ findChain suffixChains

    noWordMsg :: ByteString
    noWordMsg = BS.pack "No valid word"

    -- Read one plate and print the first dictionary word that contains it.
    findChain :: [(ByteString, WordChain)] -> IO ()
    findChain suffixChains = do
      chrs <- liftM (BS.map toLower) BS.getLine
      BS.putStrLn
        ( case find (containsSub chrs . snd) suffixChains of
            Nothing -> noWordMsg
            Just (w, _) -> w
        )

    readAsInt :: BS.ByteString -> Int
    readAsInt = fst . fromJust . BS.readInt

    readIntsLn :: IO [Int]
    readIntsLn = liftM (map readAsInt . BS.words) BS.getLine

    -- For each letter, the chain for the rest of the word after its next occurrence.
    data WordChain = None | Rest (Array Char WordChain)

    emptyChars :: WordChain
    emptyChars = Rest . Array.listArray ('a', 'z') $ repeat None

    buildChain :: ByteString -> WordChain
    buildChain s =
      case BS.uncons s of
        Nothing -> emptyChars
        Just (hd, tl) ->
          let wc@(Rest m) = buildChain tl in
          Rest $ m // [(hd, wc)]

    containsSub :: ByteString -> WordChain -> Bool
    containsSub _ None = False
    containsSub s (Rest m) =
      case BS.uncons s of
        Nothing -> True
        Just (hd, tl) -> containsSub tl (m ! hd)
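For checking any faster version, the answer for a single plate can be stated as a plain subsequence test. The following is only a baseline sketch, not the code above: `firstMatch` is a hypothetical name, and it works on `String` with `Data.List.isSubsequenceOf` (available in base 4.8 and later) instead of the `WordChain` structure.

```haskell
module Main where

import Data.Char (toLower)
import Data.List (find, isSubsequenceOf)

-- Hypothetical reference implementation: the first dictionary word that
-- contains the plate's letters, in order, as a subsequence.
-- The plate is lowercased to match the lowercase dictionary.
firstMatch :: [String] -> String -> Maybe String
firstMatch dict plate = find (needle `isSubsequenceOf`) dict
  where
    needle = map toLower plate

main :: IO ()
main = do
  let dict = ["apple", "banana", "grape"]
  print (firstMatch dict "APE") -- Just "apple"
  print (firstMatch dict "BNN") -- Just "banana"
  print (firstMatch dict "XYZ") -- Nothing
```

This does a linear scan per plate, so it is far too slow for the large instance, but it is handy for comparing outputs on small inputs.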
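Edit: TAKE 2:
I tried building a lazy trie to avoid re-searching things I had already searched. For example, if I have already seen a triplet starting with 'a', then in future I can skip every word that does not contain 'a'. If I have already searched a triplet starting with 'ab', I can skip every word that does not contain 'ab'. And if I have already searched the exact triplet 'abc', I can just return the same result as last time. In theory this should give a significant speedup. In practice the run time is the same.
Also, without the seqs, profiling takes forever and gives bogus results (I can't guess why). With the seqs, profiling shows that most of the time is spent in forLetter (which is where the array access has been moved to, so it looks like the array access is the slow part).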
    {-# LANGUAGE TupleSections #-}
    module Main where

    import Control.Monad
    import Data.Array (Array, (!), (//))
    import qualified Data.Array as Array
    import qualified Data.Array.Base as Base
    import Data.ByteString.Char8 (ByteString)
    import qualified Data.ByteString.Char8 as BS
    import Data.Char
    import Data.Functor
    import Data.Maybe

    main :: IO ()
    main = do
      [n, m] <- readIntsLn
      dictWords <- replicateM n BS.getLine
      let suffixChainsL = map (\w -> (w, buildChain w)) dictWords
      let suffixChains = foldr seq suffixChainsL suffixChainsL
      suffixChains `seq` doProbs m suffixChains

    noWordMsg :: ByteString
    noWordMsg = BS.pack "No valid word"

    doProbs :: Int -> [(ByteString, WordChain)] -> IO ()
    doProbs m chains = replicateM_ m doProb
      where
        cf = findChain chains
        doProb =
          do
            chrs <- liftM (map toLower) getLine
            BS.putStrLn . fromMaybe noWordMsg $ cf chrs

    -- Lazy trie over plates: each level keeps only the words that can still match.
    findChain :: [(ByteString, WordChain)] -> String -> Maybe ByteString
    findChain [] = const Nothing
    findChain suffixChains@(shd : _) = doFind
      where
        letterMap :: Array Char (String -> Maybe ByteString)
        letterMap =
          Array.listArray ('a','z')
            [findChain (mapMaybe (forLetter hd) suffixChains) | hd <- [0..25]]
        endRes = Just $ fst shd
        doFind :: String -> Maybe ByteString
        doFind [] = endRes
        doFind (hd : tl) = (letterMap ! hd) tl

    forLetter :: Int -> (ByteString, WordChain) -> Maybe (ByteString, WordChain)
    forLetter c (s, WC wc) = (s,) <$> wc `Base.unsafeAt` c

    readAsInt :: BS.ByteString -> Int
    readAsInt = fst . fromJust . BS.readInt

    readIntsLn :: IO [Int]
    readIntsLn = liftM (map readAsInt . BS.words) BS.getLine

    newtype WordChain = WC (Array Char (Maybe WordChain))

    emptyChars :: WordChain
    emptyChars = WC . Array.listArray ('a', 'z') $ repeat Nothing

    buildChain :: ByteString -> WordChain
    buildChain = BS.foldr helper emptyChars
      where
        helper :: Char -> WordChain -> WordChain
        helper hd wc@(WC m) = m `seq` WC (m // [(hd, Just wc)])
Answer 0 (score: 2)
The uncons call in containsSub creates a new ByteString. Try speeding it up by keeping track of the offset into the string with an index, e.g.:
    containsSub' :: ByteString -> WordChain -> Bool
    containsSub' str wc = go 0 wc
      where len = BS.length str
            go _ None = False
            go i (Rest m) | i >= len = True
                          | otherwise = go (i+1) (m ! BS.index str i)
Answer 1 (score: 1)
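After much discussion on the #haskell and #ghc IRC channels, I found that the problem was related to this GHC bug: https://ghc.haskell.org/trac/ghc/ticket/1168
The solution was simply to change the definition of doProbs:
    doProbs m chains = cf `seq` replicateM_ m doProb
      ...
Or just compile with -fno-state-hack.
GHC's state hack optimization was causing it to unnecessarily recompute cf (and the associated letterMap) on every call.
So it had nothing to do with array access.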
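The shape of the fix can be sketched in isolation. This is a simplified illustration under assumptions: `mkLookup` and `answerAll` are hypothetical stand-ins for findChain and doProbs, the queries come from a list rather than stdin, and whether GHC actually duplicates the work without the `seq` depends on the compiler version and optimization flags.

```haskell
module Main where

-- Hypothetical stand-in for findChain: applying it to the dictionary
-- returns a closure that captures a lazily built table.
mkLookup :: [String] -> String -> Maybe Int
mkLookup ws = \q -> lookup q table
  where
    table = zip ws [0 ..]

-- Hypothetical stand-in for doProbs. Forcing cf before the loop makes it
-- explicit that one closure (and one table) is shared by every query,
-- instead of leaving GHC free to rebuild it inside each IO action.
answerAll :: [String] -> [String] -> IO ()
answerAll ws queries = cf `seq` mapM_ answerOne queries
  where
    cf = mkLookup ws
    answerOne q = putStrLn (maybe "No valid word" show (cf q))

main :: IO ()
main = answerAll ["apple", "banana"] ["banana", "cherry"]
```

The `seq` only evaluates cf to a function value; the table inside it is still built lazily on first use, but it is then built once and reused.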