我使用的是GHC 8.0.1。 我编写了简单的矩阵乘法基准测试,我尝试了两种不同的矩阵数据类型方法:
1)Matrix 包含 4 个向量,每个向量含 4 个 Float 元素。
2)Matrix 直接包含 16 个 Float 元素。
我认为这两种表示应该是等价的,因为在第一种情况下,向量会被解包(UNPACK)到矩阵数据类型中。事实上,当我查看生成的 Core 时,情况似乎确实如此。此外,两个 mult 函数生成的 Core 似乎也是相同的。
然而,第二个版本比第一个版本慢2倍。
main.hs:
module Main where
import Criterion.Main
import Control.Monad
import System.Random.MWC
import qualified MatMul.Unrolled as U
import qualified MatMul.UnrolledFull as F
-- | Benchmark entry point.
--
-- Draws two successive lists of 16 random 'Float's from a single MWC
-- generator and times 4x4 matrix multiplication for both matrix
-- representations on the same pair of inputs.
main :: IO ()
main = do
  gen <- create
  lhsElems <- mkElems gen
  rhsElems <- mkElems gen
  -- 'nf' evaluates each product to normal form through its 'NFData'
  -- instance, so that traversal is part of every measurement.
  defaultMain [bgroup "matrix mult 4x4" (benchmarks lhsElems rhsElems)]
  where
    -- One benchmark per representation, both fed the same element lists.
    benchmarks lhs rhs =
      [ bench "unrolled" (nf (U.mult (U.mkMat lhs)) (U.mkMat rhs))
      , bench "unrolledFull" (nf (F.mult (F.mkMat lhs)) (F.mkMat rhs))
      ]

    -- see http://hackage.haskell.org/package/mwc-random-0.13.5.0/docs/System-Random-MWC.html#v:uniform
    -- Sixteen samples, each shifted down by 'offset'.
    mkElems g = replicateM 16 (subtract offset <$> (uniform g :: IO Float))

    -- The shift applied to every sample: 2^-33 as a 'Float'.
    offset = (2 :: Float) ^^ (-33 :: Int)
unrolled.hs:
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveAnyClass #-}
module MatMul.Unrolled where
import Control.DeepSeq
import GHC.Generics
-- | Four single-precision components stored as strict, unpacked fields.
--
-- The 'NFData' instance is not hand-written: it comes from the class
-- defaults via 'Generic' (DeriveAnyClass).
data Vector4f = Vector4f
  {-# UNPACK #-} !Float {-# UNPACK #-} !Float
  {-# UNPACK #-} !Float {-# UNPACK #-} !Float
  deriving (Eq, Show, Generic, NFData)
-- | A 4x4 matrix held as four strict 'Vector4f' fields, each marked
-- UNPACK so that its components are stored in this constructor.
--
-- The 'NFData' instance is not hand-written: it comes from the class
-- defaults via 'Generic' (DeriveAnyClass).
data Matrix4x4f = Matrix4x4f
  {-# UNPACK #-} !Vector4f {-# UNPACK #-} !Vector4f
  {-# UNPACK #-} !Vector4f {-# UNPACK #-} !Vector4f
  deriving (Eq, Show, Generic, NFData)
-- | Build a matrix from the first 16 elements of a list.
--
-- Each run of four consecutive elements becomes one 'Vector4f'; any
-- elements past the sixteenth are ignored.
--
-- Calls 'error' with a descriptive message when the list has fewer than
-- 16 elements, rather than failing with an anonymous pattern-match
-- exception.
mkMat :: [Float] -> Matrix4x4f
mkMat (a:b:c:d:e:f:g:h:i:j:k:l:m:n:o:p:_) =
  Matrix4x4f
    (Vector4f a b c d)
    (Vector4f e f g h)
    (Vector4f i j k l)
    (Vector4f m n o p)
-- Only reached for lists shorter than 16 elements, so 'length' is cheap
-- and always terminates here.
mkMat xs =
  error
    ( "MatMul.Unrolled.mkMat: expected at least 16 elements, got "
        ++ show (length xs)
    )
-- | Multiply two 4x4 matrices, with every entry written out explicitly.
--
-- Locals are named @lRC@ / @rRC@ for the entry in the R-th 'Vector4f' and
-- C-th component of the left / right operand. Result entry (i, j) is
-- @li0*r0j + li1*r1j + li2*r2j + li3*r3j@, summed left to right.
mult :: Matrix4x4f -> Matrix4x4f -> Matrix4x4f
mult
  (Matrix4x4f
    (Vector4f l00 l01 l02 l03)
    (Vector4f l10 l11 l12 l13)
    (Vector4f l20 l21 l22 l23)
    (Vector4f l30 l31 l32 l33))
  (Matrix4x4f
    (Vector4f r00 r01 r02 r03)
    (Vector4f r10 r11 r12 r13)
    (Vector4f r20 r21 r22 r23)
    (Vector4f r30 r31 r32 r33)) =
  Matrix4x4f
    (Vector4f
      (l00 * r00 + l01 * r10 + l02 * r20 + l03 * r30)
      (l00 * r01 + l01 * r11 + l02 * r21 + l03 * r31)
      (l00 * r02 + l01 * r12 + l02 * r22 + l03 * r32)
      (l00 * r03 + l01 * r13 + l02 * r23 + l03 * r33))
    (Vector4f
      (l10 * r00 + l11 * r10 + l12 * r20 + l13 * r30)
      (l10 * r01 + l11 * r11 + l12 * r21 + l13 * r31)
      (l10 * r02 + l11 * r12 + l12 * r22 + l13 * r32)
      (l10 * r03 + l11 * r13 + l12 * r23 + l13 * r33))
    (Vector4f
      (l20 * r00 + l21 * r10 + l22 * r20 + l23 * r30)
      (l20 * r01 + l21 * r11 + l22 * r21 + l23 * r31)
      (l20 * r02 + l21 * r12 + l22 * r22 + l23 * r32)
      (l20 * r03 + l21 * r13 + l22 * r23 + l23 * r33))
    (Vector4f
      (l30 * r00 + l31 * r10 + l32 * r20 + l33 * r30)
      (l30 * r01 + l31 * r11 + l32 * r21 + l33 * r31)
      (l30 * r02 + l31 * r12 + l32 * r22 + l33 * r32)
      (l30 * r03 + l31 * r13 + l32 * r23 + l33 * r33))
unrolledFull.hs:
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveAnyClass #-}
module MatMul.UnrolledFull where
import Control.DeepSeq
import GHC.Generics
-- | A 4x4 matrix held as sixteen strict, unpacked 'Float' fields in a
-- single constructor. The fields are laid out below four per line.
--
-- The 'NFData' instance is not hand-written: it comes from the class
-- defaults via 'Generic' (DeriveAnyClass).
data Matrix4x4f = Matrix4x4f
  {-# UNPACK #-} !Float {-# UNPACK #-} !Float {-# UNPACK #-} !Float {-# UNPACK #-} !Float
  {-# UNPACK #-} !Float {-# UNPACK #-} !Float {-# UNPACK #-} !Float {-# UNPACK #-} !Float
  {-# UNPACK #-} !Float {-# UNPACK #-} !Float {-# UNPACK #-} !Float {-# UNPACK #-} !Float
  {-# UNPACK #-} !Float {-# UNPACK #-} !Float {-# UNPACK #-} !Float {-# UNPACK #-} !Float
  deriving (Eq, Show, Generic, NFData)
-- | Build a matrix from the first 16 elements of a list, in list order.
--
-- Any elements past the sixteenth are ignored.
--
-- Calls 'error' with a descriptive message when the list has fewer than
-- 16 elements, rather than failing with an anonymous pattern-match
-- exception.
mkMat :: [Float] -> Matrix4x4f
mkMat (a:b:c:d:e:f:g:h:i:j:k:l:m:n:o:p:_) =
  Matrix4x4f
    a b c d
    e f g h
    i j k l
    m n o p
-- Only reached for lists shorter than 16 elements, so 'length' is cheap
-- and always terminates here.
mkMat xs =
  error
    ( "MatMul.UnrolledFull.mkMat: expected at least 16 elements, got "
        ++ show (length xs)
    )
-- | Multiply two 4x4 matrices, with every entry written out explicitly.
--
-- Locals are named @lRC@ / @rRC@ for the field at position @4*R + C@ of
-- the left / right operand. Result entry (i, j) is
-- @li0*r0j + li1*r1j + li2*r2j + li3*r3j@, summed left to right.
mult :: Matrix4x4f -> Matrix4x4f -> Matrix4x4f
mult
  (Matrix4x4f
    l00 l01 l02 l03
    l10 l11 l12 l13
    l20 l21 l22 l23
    l30 l31 l32 l33)
  (Matrix4x4f
    r00 r01 r02 r03
    r10 r11 r12 r13
    r20 r21 r22 r23
    r30 r31 r32 r33) =
  Matrix4x4f
    (l00 * r00 + l01 * r10 + l02 * r20 + l03 * r30)
    (l00 * r01 + l01 * r11 + l02 * r21 + l03 * r31)
    (l00 * r02 + l01 * r12 + l02 * r22 + l03 * r32)
    (l00 * r03 + l01 * r13 + l02 * r23 + l03 * r33)
    (l10 * r00 + l11 * r10 + l12 * r20 + l13 * r30)
    (l10 * r01 + l11 * r11 + l12 * r21 + l13 * r31)
    (l10 * r02 + l11 * r12 + l12 * r22 + l13 * r32)
    (l10 * r03 + l11 * r13 + l12 * r23 + l13 * r33)
    (l20 * r00 + l21 * r10 + l22 * r20 + l23 * r30)
    (l20 * r01 + l21 * r11 + l22 * r21 + l23 * r31)
    (l20 * r02 + l21 * r12 + l22 * r22 + l23 * r32)
    (l20 * r03 + l21 * r13 + l22 * r23 + l23 * r33)
    (l30 * r00 + l31 * r10 + l32 * r20 + l33 * r30)
    (l30 * r01 + l31 * r11 + l32 * r21 + l33 * r31)
    (l30 * r02 + l31 * r12 + l32 * r22 + l33 * r32)
    (l30 * r03 + l31 * r13 + l32 * r23 + l33 * r33)
我得到以下结果:
Running 1 benchmarks...
Benchmark matMul: RUNNING...
benchmarking matrix mult 4x4/unrolled
time 60.32 ns (60.21 ns .. 60.43 ns)
1.000 R² (1.000 R² .. 1.000 R²)
mean 60.40 ns (60.28 ns .. 60.55 ns)
std dev 457.6 ps (352.8 ps .. 612.8 ps)
benchmarking matrix mult 4x4/unrolledFull
time 137.6 ns (137.4 ns .. 137.8 ns)
1.000 R² (1.000 R² .. 1.000 R²)
mean 137.7 ns (137.5 ns .. 138.0 ns)
std dev 874.8 ps (665.5 ps .. 1.194 ns)
Benchmark matMul: FINISH