异常缓慢的Haskell代码

时间:2014-07-09 22:06:50

标签: haskell

在F#中完成后,我一直在尝试使用Haskell中的Digits-Recognizer Dojo进行练习。我得到了结果,但由于某种原因,我的Haskell代码疯狂慢,我似乎无法找到错误。

这是我的代码(可以在Dojo的GitHub上找到.csv个文件):

import Data.Char
import Data.List
import Data.List.Split
import Data.Ord
import System.IO

type Pixels = [Int]

data Digit = Digit { label  :: Int, pixels :: Pixels }

distance :: Pixels -> Pixels -> Float
distance d1 d2 = sqrt . sum $ map pointDistance $ zip d1 d2
    where pointDistance (a, b) = fromIntegral $ (a - b) * (a - b)

parseDigit :: String -> Digit
parseDigit s = Digit label pixels
    where (label:pixels) = map read $ splitOn "," s

identify :: Digit -> [Digit] -> (Digit, Float)
identify digit training = minimumBy (comparing snd) distances
    where distances = map fn training
          fn ref    = (ref, distance (pixels digit) (pixels ref))

readDigits :: String -> IO [Digit]
readDigits filename = do
    fileContent <- readFile filename
    return $ map parseDigit $ tail $ lines fileContent

main :: IO ()
main = do
    trainingSample <- readDigits "trainingsample.csv"
    validationSample <- readDigits "validationsample.csv"
    let result               = [(d, identify d trainingSample) | d <- validationSample]
        fmt (d, (ref, dist)) = putStrLn $ "Found..."
    mapM_ fmt result

这些糟糕表现的原因是什么?


[更新] 感谢您提出的许多想法!我已根据建议将String的使用情况更改为Data.Text,并将List的使用情况更改为Data.Vector,遗憾的是结果仍然不尽如人意。

我的更新代码为available here

为了让您更好地理解我的审讯,这里是我的Haskell(左)和F#(右)实现的输出。我是这两种语言的全新手,所以我真诚地相信我的Haskell版本中存在一个重大错误,就是那么慢

Terminal capture

2 个答案:

答案 0 :(得分:6)

如果您有耐心,您会注意到第二个结果的计算速度比第一个快得多。那是因为你的实现需要一些时间来读取csv文件。

您可能会想要粘贴一个打印语句,看看它何时完成加载:

main = do
    trainingSample <- readDigits "trainingsample.csv"
    validationSample <- readDigits "validationsample.csv"
    putStrLn "done loading data"

但由于懒惰,这不会做你认为它做的事情。 trainingSamplevalidationSample尚未完全评估。所以你的print语句几乎会立即打印出来,第一个结果仍然需要永远。

您可以强制readDigits完全评估他们的返回值,这样可以让您更好地了解在那里花了多少时间。您可以切换到使用非惰性IO,或只打印从数据派生的内容:

readDigits :: String -> IO [Digit]
readDigits filename = do
    fileContent <- readFile filename
    putStr' $ filename ++ ": "
    rows <- forM (tail $ lines fileContent) $ \line -> do
      let xs = parseDigit line
      putStr' $ case compare (sum $ pixels xs) 0 of
                LT -> "-"
                EQ -> "0"
                GT -> "+"
      return xs
    putStrLn ""
    return rows
  where putStr' s = putStr s >> hFlush stdout

在我的机器上,让我看到完全读取trainingsample.csv的数字需要大约27秒。

这是printf风格的分析,它不是很好(使用真正的分析器要好得多,或者使用标准来对代码的各个部分进行基准测试),但是对于这些目的来说足够好。

这显然是经济放缓的主要部分,因此值得尝试转向严格的io。使用严格的Data.Text.IO.readFile将其缩短到约18秒。


更新

以下是加快更新代码的方法:

  1. 使用未装箱的向量Pixels(小赢):

    import qualified Data.Vector.Unboxed as U
    -- ...
    
    type Pixels = U.Vector Int
    -- ...
    
    distance :: Pixels -> Pixels -> Float
    distance d1 d2 = sqrt . U.sum $ U.zipWith pointDistance d1 d2
        where pointDistance a b = fromIntegral $ (a - b) * (a - b)
    
    parseDigit :: T.Text -> Digit
    parseDigit s = Digit label (U.fromList pixels)
        where (label:pixels) = map toDigit $ T.splitOn (T.pack ",") s
              toDigit s      = either (\_ -> 0) fst (T.Read.decimal s)
    
  2. 使用seq(大赢)强制进行距离评估:

    identify :: Digit -> V.Vector Digit -> (Digit, Float)
    identify digit training = V.minimumBy (comparing snd) distances
        where distances = V.map fn training
              fn ref    = let d = distance (pixels digit) (pixels ref) in d `seq` (ref, d)
    
  3. 在我的机器上,整个程序现在运行~5s:

    % ghc --make -O2 Main.hs
    [1 of 1] Compiling Main             ( Main.hs, Main.o )
    Linking Main ...
    % time ./Main
    ./Main  5.00s user 0.11s system 99% cpu 5.115 total
    

    th th杀了你。

答案 1 :(得分:3)

您的Vector版本,部分取消装箱,适用于ByteString并使用-O2 -fllvm编译,在我的机器上运行8秒钟:

import Data.Ord
import Data.Maybe
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC


type Pixels = U.Vector Int

data Digit = Digit { label :: !Int, pixels :: !Pixels }


distance :: Pixels -> Pixels -> Float
distance d1 d2 = sqrt . U.sum . U.zipWith pointDistance d1 $ d2
    where pointDistance a b = fromIntegral $ (a - b) * (a - b)

parseDigit :: B.ByteString -> Digit
parseDigit bs =
    let (label:pixels) = toIntegers bs []
    in Digit label (U.fromList pixels)
    where
      toIntegers bs is =
          let Just (i,bs') = BC.readInt bs
          in if B.null bs' then reverse is else toIntegers (BC.tail bs') (i:is)               

identify :: Digit -> V.Vector Digit -> (Digit, Float)
identify digit training = V.minimumBy (comparing snd) distances
    where distances = V.map fn training
          fn ref    = (ref, distance (pixels digit) (pixels ref))

readDigits :: String -> IO (V.Vector Digit)
readDigits filename = do
    fileContent <- B.readFile filename
    return . V.map parseDigit . V.fromList . tail . BC.lines $ fileContent


main :: IO ()
main = do
    trainingSample <- readDigits "trainingsample.csv"
    validationSample <- readDigits "validationsample.csv"
    let result = V.map (\d -> (d, identify d trainingSample)) validationSample
        fmt (d, (ref, dist)) = putStrLn $ "Found " ++ show (label ref) ++ " for " ++ show (label d) ++ " (distance=" ++ show dist ++ ")"
    V.mapM_ fmt result

+RTS -s的输出:

     989,632,984 bytes allocated in the heap
      19,875,368 bytes copied during GC
      31,016,504 bytes maximum residency (5 sample(s))
      22,748,608 bytes maximum slop
              78 MB total memory in use (1 MB lost due to fragmentation)

                                    Tot time (elapsed)  Avg pause  Max pause
  Gen  0      1761 colls,     0 par    0.05s    0.05s     0.0000s    0.0008s
  Gen  1         5 colls,     0 par    0.00s    0.02s     0.0030s    0.0085s

  INIT    time    0.00s  (  0.00s elapsed)
  MUT     time    7.42s  (  7.69s elapsed)
  GC      time    0.05s  (  0.06s elapsed)
  EXIT    time    0.00s  (  0.01s elapsed)
  Total   time    7.47s  (  7.77s elapsed)

  %GC     time       0.7%  (0.8% elapsed)

  Alloc rate    133,419,569 bytes per MUT second

  Productivity  99.3% of total user, 95.5% of total elapsed