我在Haskell上实现了Winograd算法,并试图改进它,决定让它更严格。我做到了这一点,但我不明白为什么它开始更快地工作。 Yesterday我问了一个类似的问题,但发布的代码不太正确。
由于该代码没有显示问题,我完全发布了代码。
module Main where
import qualified Data.Vector as V
import qualified Data.Matrix as M
import Control.DeepSeq
import Control.Exception
import System.Clock
import System.Mem
import Data.Time
matrixCtor :: Int -> Int -> Int -> M.Matrix Int
matrixCtor x y size = M.matrix size size $ \(i,j) -> x*i+y*j
winogradLazy :: M.Matrix Int -> M.Matrix Int -> M.Matrix Int
winogradLazy a b = c
where
n = M.nrows a
p = M.ncols a
m = M.ncols b
-- Translate into vectors, since indexing in matrices takes longer.
-- Matrix b is also transposed
a' = V.generate n $ \i -> M.getRow (i+1) a
bt' = V.generate m $ \j -> M.getCol (j+1) b
rows = V.generate n $ \i -> group $ V.unsafeIndex a' i
cols = V.generate m $ \j -> group $ V.unsafeIndex bt' j
group x = foldl (groupHelper x) 0 [0,2..p-1]
groupHelper x acc i = let
x1 = V.unsafeIndex x (i)
x2 = V.unsafeIndex x (i+1)
in
acc - x1 * x2
c = M.matrix n m $ \(i,j) ->
let
a = V.unsafeIndex rows (i-1) + V.unsafeIndex cols (j-1)
b = wsum (V.unsafeIndex a' (i-1)) (V.unsafeIndex bt' (j-1))
in
a + b
wsum r c = foldl (wsumHelper r c) 0 [0,2..p-1]
wsumHelper r c acc i = let
x1 = V.unsafeIndex r (i)
x2 = V.unsafeIndex r (i+1)
y1 = V.unsafeIndex c (i)
y2 = V.unsafeIndex c (i+1)
in
acc +(x1+y2)*(x2+y1)
winogradStrict :: M.Matrix Int -> M.Matrix Int -> M.Matrix Int
winogradStrict a b = c
where
n = M.nrows a
p = M.ncols a
m = M.ncols b
-- Translate into vectors, since indexing in matrices takes longer.
-- Matrix b is also transposed
a' = V.generate n $ \i -> M.getRow (i+1) a
bt' = V.generate m $ \j -> M.getCol (j+1) b
rows = V.generate n $ \i -> group $ V.unsafeIndex a' i
cols = V.generate m $ \j -> group $ V.unsafeIndex bt' j
group x = foldl (groupHelper x) 0 [0,2..p-1]
groupHelper x acc i = let
x1 = V.unsafeIndex x (i)
x2 = V.unsafeIndex x (i+1)
in
acc - x1 * x2
c = a' `deepseq` bt' `deepseq` M.matrix n m $ \(i,j) ->
let
a = V.unsafeIndex rows (i-1) + V.unsafeIndex cols (j-1)
b = wsum (V.unsafeIndex a' (i-1)) (V.unsafeIndex bt' (j-1))
in
a + b
wsum r c = foldl (wsumHelper r c) 0 [0,2..p-1]
wsumHelper r c acc i = let
x1 = V.unsafeIndex r (i)
x2 = V.unsafeIndex r (i+1)
y1 = V.unsafeIndex c (i)
y2 = V.unsafeIndex c (i+1)
in
acc + (x1+y2)*(x2+y1)
lazyTest :: IO ()
lazyTest = do
let a = matrixCtor 2 (-1) 500
let b = matrixCtor 2 (-3) 500
evaluate $ force a
evaluate $ force b
start <- getCurrentTime
let c = winogradLazy a b
evaluate $ force c
end <- getCurrentTime
print (diffUTCTime end start)
strictTest :: IO ()
strictTest = do
let a = matrixCtor 2 (-1) 500
let b = matrixCtor 2 (-3) 500
evaluate $ force a
evaluate $ force b
start <- getCurrentTime
let c = winogradStrict a b
evaluate $ force c
end <- getCurrentTime
print (diffUTCTime end start)
main :: IO ()
main = do
performMajorGC
lazyTest
performMajorGC
strictTest
在计算矩阵c
之前的严格版本中,我执行以下操作:
a' `deepseq` bt' `deepseq`
因此我得到以下结果
2.083201s --lazyTest
0.613508s --strictTest