I have the following matrix multiplication function:
winograd :: M.Matrix Int -> M.Matrix Int -> M.Matrix Int
winograd a b = c
  where
    ra = M.nrows a
    ca = M.ncols a
    rb = M.nrows b
    cb = M.ncols b
    isEven = even ca
    avs = V.generate ra $ \i -> M.getRow (i+1) a
    bvs = V.generate cb $ \j -> M.getCol (j+1) b
    rows = V.generate ra $ \i -> wgdGroup $ V.unsafeIndex avs i
    cols = V.generate cb $ \j -> wgdGroup $ V.unsafeIndex bvs j
    wgdGroup x =
      let
        finish = (V.length x - 1)
      in
        forLoopFold 0 (<finish) (+2) 0 $
          \acc i -> acc
            - V.unsafeIndex x i * V.unsafeIndex x (i+1)
    c = if isEven then
          M.matrix ra cb $
            \(i,j) ->
              V.unsafeIndex rows (i-1) +
              V.unsafeIndex cols (j-1) +
              g (V.unsafeIndex avs (i-1)) (V.unsafeIndex bvs (j-1))
        else
          M.matrix ra cb $
            \(i,j) ->
              V.unsafeIndex rows (i-1) +
              V.unsafeIndex cols (j-1) +
              g (V.unsafeIndex avs (i-1)) (V.unsafeIndex bvs (j-1)) +
              V.last (V.unsafeIndex avs (i-1)) * V.last (V.unsafeIndex bvs (j-1))
    g r c = forLoopFold 0 (<(ca-1)) (+2) 0 $ \acc i ->
      let
        x1 = V.unsafeIndex r i
        x2 = V.unsafeIndex r (i+1)
        y1 = V.unsafeIndex c i
        y2 = V.unsafeIndex c (i+1)
      in
        acc + (x1+y2)*(x2+y1)
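For context, here is a minimal sketch of the per-element identity the function appears to rely on, written with plain lists instead of vectors. The names dot, pairTerm, crossTerm and winogradDot are illustrative only and do not appear in the code above; pairTerm is meant to mirror wgdGroup and crossTerm to mirror g, and both input lists are assumed to have the same length.

```haskell
module Main where

-- Plain dot product, used as the reference result.
dot :: [Int] -> [Int] -> Int
dot xs ys = sum (zipWith (*) xs ys)

-- Negated sum of products of adjacent pairs: -(x0*x1 + x2*x3 + ...).
-- A trailing unpaired element is ignored, as in wgdGroup.
pairTerm :: [Int] -> Int
pairTerm (x1:x2:rest) = pairTerm rest - x1 * x2
pairTerm _            = 0

-- Cross term: sum of (x1 + y2) * (x2 + y1) over adjacent pairs, as in g.
crossTerm :: [Int] -> [Int] -> Int
crossTerm (x1:x2:xs) (y1:y2:ys) = (x1 + y2) * (x2 + y1) + crossTerm xs ys
crossTerm _ _                   = 0

-- Winograd-style inner product (assumes both lists have equal length).
-- For odd lengths the last pair of elements is added separately,
-- which corresponds to the else branch of c.
winogradDot :: [Int] -> [Int] -> Int
winogradDot xs ys = pairTerm xs + pairTerm ys + crossTerm xs ys + tailTerm
  where
    tailTerm
      | odd (length xs) = last xs * last ys
      | otherwise       = 0

main :: IO ()
main = do
  let xs = [1, 2, 3, 4, 5]
      ys = [6, 7, 8, 9, 10]
  -- Compare against the ordinary dot product.
  print (winogradDot xs ys == dot xs ys)
```

Expanding (x1 + y2) * (x2 + y1) gives x1*x2 + x1*y1 + x2*y2 + y1*y2, so subtracting the two pair terms leaves x1*y1 + x2*y2. The pair terms depend on only one row or one column, which is why winograd precomputes them once in rows and cols.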
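I use the clock and formatting packages to measure the execution time of this function as follows (yes, I know about Criterion, but for now I need to measure it this way):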
main :: IO ()
main = do
  let a = simple 2 (-1) 1000
  let b = simple 2 (-3) 1000
  start <- getTime Realtime
  let
    c = winograd a b
  end <- c `deepseq` getTime Realtime
  fprint (timeSpecs % "\n") start end
where simple is
simple :: Int -> Int -> Int -> M.Matrix Int
simple x y size = M.matrix size size $ \(i,j) -> x*i+y*j
The result I get is about 5 seconds. But when I get rid of the function simple and do this instead:
main :: IO ()
main = do
  let a = M.matrix 1000 1000 $ \(i,j) -> 2*i-1*j
  let b = M.matrix 1000 1000 $ \(i,j) -> 2*i-3*j
  start <- getTime Realtime
  let
    c = winograd a b
  end <- c `deepseq` getTime Realtime
  fprint (timeSpecs % "\n") start end
the time increases to 15 seconds!
I would be very interested to know why. Compiled with the -O2 flag.