在F#中完成后,我一直在尝试使用Haskell中的Digits-Recognizer Dojo进行练习。我得到了结果,但由于某种原因,我的Haskell代码疯狂慢,我似乎无法找到错误。
这是我的代码(可以在Dojo的GitHub上找到.csv
个文件):
import Data.Char
import Data.List
import Data.List.Split
import Data.Ord
import System.IO
type Pixels = [Int]
data Digit = Digit { label :: Int, pixels :: Pixels }
distance :: Pixels -> Pixels -> Float
distance d1 d2 = sqrt . sum $ map pointDistance $ zip d1 d2
where pointDistance (a, b) = fromIntegral $ (a - b) * (a - b)
parseDigit :: String -> Digit
parseDigit s = Digit label pixels
where (label:pixels) = map read $ splitOn "," s
identify :: Digit -> [Digit] -> (Digit, Float)
identify digit training = minimumBy (comparing snd) distances
where distances = map fn training
fn ref = (ref, distance (pixels digit) (pixels ref))
readDigits :: String -> IO [Digit]
readDigits filename = do
fileContent <- readFile filename
return $ map parseDigit $ tail $ lines fileContent
main :: IO ()
main = do
trainingSample <- readDigits "trainingsample.csv"
validationSample <- readDigits "validationsample.csv"
let result = [(d, identify d trainingSample) | d <- validationSample]
fmt (d, (ref, dist)) = putStrLn $ "Found..."
mapM_ fmt result
这些糟糕表现的原因是什么?
[更新] 感谢您提出的许多想法!我已根据建议将String
的使用情况更改为Data.Text
,并将List
的使用情况更改为Data.Vector
,遗憾的是结果仍然不尽如人意。
我的更新代码为available here。
为了让您更好地理解我的审讯,这里是我的Haskell(左)和F#(右)实现的输出。我是这两种语言的全新手,所以我真诚地相信我的Haskell版本中存在一个重大错误,就是那么慢。
答案 0 :(得分:6)
如果您有耐心,您会注意到第二个结果的计算速度比第一个快得多。那是因为你的实现需要一些时间来读取csv文件。
您可能会想要粘贴一个打印语句,看看它何时完成加载:
main = do
trainingSample <- readDigits "trainingsample.csv"
validationSample <- readDigits "validationsample.csv"
putStrLn "done loading data"
但由于懒惰,这不会做你认为它做的事情。 trainingSample
和validationSample
尚未完全评估。所以你的print语句几乎会立即打印出来,第一个结果仍然需要永远。
您可以强制readDigits
完全评估他们的返回值,这样可以让您更好地了解在那里花了多少时间。您可以切换到使用非惰性IO,或只打印从数据派生的内容:
readDigits :: String -> IO [Digit]
readDigits filename = do
fileContent <- readFile filename
putStr' $ filename ++ ": "
rows <- forM (tail $ lines fileContent) $ \line -> do
let xs = parseDigit line
putStr' $ case compare (sum $ pixels xs) 0 of
LT -> "-"
EQ -> "0"
GT -> "+"
return xs
putStrLn ""
return rows
where putStr' s = putStr s >> hFlush stdout
在我的机器上,让我看到完全读取trainingsample.csv
的数字需要大约27秒。
这是printf风格的分析,它不是很好(使用真正的分析器要好得多,或者使用标准来对代码的各个部分进行基准测试),但是对于这些目的来说足够好。
这显然是经济放缓的主要部分,因此值得尝试转向严格的io。使用严格的Data.Text.IO.readFile
将其缩短到约18秒。
更新
以下是加快更新代码的方法:
使用未装箱的向量Pixels
(小赢):
import qualified Data.Vector.Unboxed as U
-- ...
type Pixels = U.Vector Int
-- ...
distance :: Pixels -> Pixels -> Float
distance d1 d2 = sqrt . U.sum $ U.zipWith pointDistance d1 d2
where pointDistance a b = fromIntegral $ (a - b) * (a - b)
parseDigit :: T.Text -> Digit
parseDigit s = Digit label (U.fromList pixels)
where (label:pixels) = map toDigit $ T.splitOn (T.pack ",") s
toDigit s = either (\_ -> 0) fst (T.Read.decimal s)
使用seq
(大赢)强制进行距离评估:
identify :: Digit -> V.Vector Digit -> (Digit, Float)
identify digit training = V.minimumBy (comparing snd) distances
where distances = V.map fn training
fn ref = let d = distance (pixels digit) (pixels ref) in d `seq` (ref, d)
在我的机器上,整个程序现在运行~5s:
% ghc --make -O2 Main.hs
[1 of 1] Compiling Main ( Main.hs, Main.o )
Linking Main ...
% time ./Main
./Main 5.00s user 0.11s system 99% cpu 5.115 total
th th杀了你。
答案 1 :(得分:3)
您的Vector版本,部分取消装箱,适用于ByteString并使用-O2 -fllvm
编译,在我的机器上运行8秒钟:
import Data.Ord
import Data.Maybe
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BC
type Pixels = U.Vector Int
data Digit = Digit { label :: !Int, pixels :: !Pixels }
distance :: Pixels -> Pixels -> Float
distance d1 d2 = sqrt . U.sum . U.zipWith pointDistance d1 $ d2
where pointDistance a b = fromIntegral $ (a - b) * (a - b)
parseDigit :: B.ByteString -> Digit
parseDigit bs =
let (label:pixels) = toIntegers bs []
in Digit label (U.fromList pixels)
where
toIntegers bs is =
let Just (i,bs') = BC.readInt bs
in if B.null bs' then reverse is else toIntegers (BC.tail bs') (i:is)
identify :: Digit -> V.Vector Digit -> (Digit, Float)
identify digit training = V.minimumBy (comparing snd) distances
where distances = V.map fn training
fn ref = (ref, distance (pixels digit) (pixels ref))
readDigits :: String -> IO (V.Vector Digit)
readDigits filename = do
fileContent <- B.readFile filename
return . V.map parseDigit . V.fromList . tail . BC.lines $ fileContent
main :: IO ()
main = do
trainingSample <- readDigits "trainingsample.csv"
validationSample <- readDigits "validationsample.csv"
let result = V.map (\d -> (d, identify d trainingSample)) validationSample
fmt (d, (ref, dist)) = putStrLn $ "Found " ++ show (label ref) ++ " for " ++ show (label d) ++ " (distance=" ++ show dist ++ ")"
V.mapM_ fmt result
+RTS -s
的输出:
989,632,984 bytes allocated in the heap
19,875,368 bytes copied during GC
31,016,504 bytes maximum residency (5 sample(s))
22,748,608 bytes maximum slop
78 MB total memory in use (1 MB lost due to fragmentation)
Tot time (elapsed) Avg pause Max pause
Gen 0 1761 colls, 0 par 0.05s 0.05s 0.0000s 0.0008s
Gen 1 5 colls, 0 par 0.00s 0.02s 0.0030s 0.0085s
INIT time 0.00s ( 0.00s elapsed)
MUT time 7.42s ( 7.69s elapsed)
GC time 0.05s ( 0.06s elapsed)
EXIT time 0.00s ( 0.01s elapsed)
Total time 7.47s ( 7.77s elapsed)
%GC time 0.7% (0.8% elapsed)
Alloc rate 133,419,569 bytes per MUT second
Productivity 99.3% of total user, 95.5% of total elapsed