使用列表优雅地实现n维矩阵乘法?

时间:2015-07-02 22:17:28

标签: haskell functional-programming linear-algebra

列表函数允许我们非常优雅地实现任意维向量数学。例如:

on   = (.) . (.)
add  = zipWith (+)
sub  = zipWith (-)
mul  = zipWith (*)
dist = len `on` sub
dot  = sum `on` mul
len  = sqrt . join dot

等等。

main = print $ add [1,2,3] [1,1,1] -- [2,3,4]
main = print $ len [1,1,1]         -- 1.7320508075688772
main = print $ dot [2,0,0] [2,0,0] -- 4

当然,这不是最有效的解决方案,但是要有洞察力,因为可以说mapzipWith这样可以概括那些向量运算。但是,有一个我无法优雅地实现的功能 - 即交叉产品。由于交叉积的可能n维推广是nd矩阵行列式,如何优雅地实现矩阵乘法?

编辑:是的,我问了一个与我设置的问题完全无关的问题。 FML。

1 个答案:

答案 0 :(得分:4)

就是这样,我有一些代码可以进行n维矩阵运算,我认为至少在编写时非常可爱:

{-# LANGUAGE NoMonomorphismRestriction #-}
module MultiArray where

import Control.Arrow
import Control.Monad
import Data.Ix
import Data.Maybe

import Data.Array (Array)
import qualified Data.Array as A

-- {{{ from Dmwit.hs
deleteAt n   xs = take n xs ++ drop (n + 1) xs
insertAt n x xs = take n xs ++ x : drop n xs

doublify f g xs ys = f (uncurry g) (zip xs ys)
any2 = doublify any
all2 = doublify all
-- }}}

-- makes the most sense when ls and hs have the same length
instance Ix a => Ix [a] where
    range     = sequence . map range . uncurry zip
    inRange   = all2 inRange . uncurry zip
    rangeSize = product . uncurry (zipWith (curry rangeSize))

    index (ls, hs) xs = fst . foldr step (0, 1) $ zip indices sizes where
        indices = zipWith index (zip ls hs) xs
        sizes   = map rangeSize $ zip ls hs
        step (i, b) (s, p) = (s + p * i, p * b)

fold :: (Enum i, Ix i) => ([a] -> b) -> Int -> Array [i] a -> Array [i] b
fold f n a = A.array newBound assocs where
    (oldLowBound, oldHighBound) = A.bounds a
    (newLowBoundBeg , dimLow : newLowBoundEnd ) = splitAt n oldLowBound
    (newHighBoundBeg, dimHigh: newHighBoundEnd) = splitAt n oldHighBound
    assocs   = [(beg ++ end, f [a A.! (beg ++ i : end) | i <- [dimLow..dimHigh]])
               | beg <- range (newLowBoundBeg, newHighBoundBeg)
               , end <- range (newLowBoundEnd, newHighBoundEnd)
               ]
    newBound = (newLowBoundBeg ++ newLowBoundEnd, newHighBoundBeg ++ newHighBoundEnd)

flatten a = check a >> return value where
    check = guard . (1==) . length . fst . A.bounds
    value = A.ixmap ((head *** head) . A.bounds $ a) return a

elementWise :: (MonadPlus m, Ix i) => (a -> b -> c) -> Array i a -> Array i b -> m (Array i c)
elementWise f a b = check >> return value where
    check = guard $ A.bounds a == A.bounds b
    value = A.listArray (A.bounds a) (zipWith f (A.elems a) (A.elems b))

unsafeFlatten       a   = fromJust $ flatten       a
unsafeElementWise f a b = fromJust $ elementWise f a b

matrixMult a b = fold sum 1 $ unsafeElementWise (*) a' b' where
    aBounds = (join (***) (!!0)) $ A.bounds a
    bBounds = (join (***) (!!1)) $ A.bounds b
    a' = copy 2 bBounds a
    b' = copy 0 aBounds b

bijection f g a = A.ixmap ((f *** f) . A.bounds $ a) g a
unFlatten       = bijection return head
matrixTranspose = bijection reverse reverse
copy n (low, high) a = A.ixmap (newBounds a) (deleteAt n) a where
    newBounds = (insertAt n low *** insertAt n high) . A.bounds

这里的可爱位是matrixMult,这是专门用于二维数组的唯一操作之一。它沿着一个维度扩展其第一个参数(通过将二维对象的副本放入三维对象的每个切片中);沿着另一个扩展其第二个;做逐点乘法(现在是三维数组);然后通过求和折叠制造的第三维度。非常好。