Haskell GADTs - 为黎曼几何体制作类型安全的Tensor类型

时间:2017-04-01 12:16:31

标签: haskell type-safety gadt data-kinds

我想使用GADT在Haskell中进行Tensor演算的类型安全实现,因此规则是:

  1. 张量是n维度的矩阵,其中包含的内容可能是楼上的'或楼下'例如:enter image description here - 是一个没有任何差异的标量(标量),enter image description here是一个Tensor,其中有一个在楼上' index,enter image description here是一个有一堆楼上的张量'和楼下的' indecies
  2. 您可以添加相同类型的张量,这意味着它们具有相同的indecies签名。第一张量的第0个索引与第二张量的第0个索引属于同一类型(楼上或楼下),依此类推......

    enter image description here ~~~~确定

    enter image description here ~~~~不行

  3. 您可以使用MULTIPLY张量并获得更大的张量,并将这些内容连接起来:enter image description here

  4. 所以我希望Haskell的类型检查器不允许我编写那些不遵循这些规则的代码,否则就不会编译。

    以下是我尝试使用GADT的方法:

    {-# LANGUAGE GADTs #-}
    {-# LANGUAGE DataKinds #-}
    {-# LANGUAGE ExistentialQuantification #-}
    {-# LANGUAGE TypeOperators #-}
    
    data Direction = T | X | Y | Z
    data Index = Zero | Up Index | Down Index deriving (Eq, Show)
    
    plus :: Index -> Index -> Index
    plus Zero x = x
    plus (Up x) y = Up (plus x y)
    plus (Down x) y = Down (plus x y)
    
    data Tensor a = (a ~ Zero) => Scalar Double | 
                    forall b. (a ~ Up b) => Cov (Direction -> Tensor b) |
                    forall b. (a ~ Down b) => Con (Direction -> Tensor b) 
    
    add :: Tensor a -> Tensor a -> Tensor a
    add (Scalar x) (Scalar y) = (Scalar (x + y))
    add (Cov f) (Cov g) = (Cov (\d -> add (f d) (g d)))
    add (Con f) (Con g) = (Con (\d -> add (f d) (g d)))
    
    mul :: Tensor a -> Tensor b -> Tensor (plus a b)
    mul (Scalar x) (Scalar y) = (Scalar (x*y))
    mul (Scalar x) (Cov f) = (Cov (\d -> mul (Scalar x) (f d)))
    mul (Scalar x) (Con f) = (Con (\d -> mul (Scalar x) (f d)))
    mul (Cov f) y = (Cov (\d -> mul (f d) y))
    mul (Con f) y = (Con (\d -> mul (f d) y))
    

    但我得到了:

    Couldn't match type 'Down with `plus ('Down b1)'                                                                                                                                                                                                    
        Expected type: Tensor (plus a b)                                                                                                                                                                                                                    
          Actual type: Tensor ('Down b)                                                                                                                                                                                                                     
        Relevant bindings include                                                                                                                                                                                                                           
          f :: Direction -> Tensor b1 (bound at main.hs:28:10)                                                                                                                                                                                              
          mul :: Tensor a -> Tensor b -> Tensor (plus a b)                                                                                                                                                                                                  
            (bound at main.hs:24:1)                                                                                                                                                                                                                         
        In the expression: (Con (\ d -> mul (f d) y))                                                                                                                                                                                                       
        In an equation for `mul':                                                                                                                                                                                                                           
            mul (Con f) y = (Con (\ d -> mul (f d) y)) 
    

    有什么问题?

1 个答案:

答案 0 :(得分:3)

plus is just a function on values of type Index

>>> plus Zero Zero
Zero
>>> plus Zero (Up Zero)
Up Zero

so it can't appear in a type signature, as things are. You want to use the 'promoted' type where Zero, Up Zero etc. are types. Then you can write a type function and everything compiles.

{-# LANGUAGE GADTs #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}

data Direction = T | X | Y | Z
data Index = Zero | Up Index | Down Index deriving (Eq, Show)

-- type function Plus
type family Plus (i :: Index) (j :: Index) :: Index where
  Plus Zero x = x
  Plus (Up x) y  = Up (Plus x y)
  Plus (Down x) y = Down (Plus x y)

-- value fuction plus
plus :: Index -> Index -> Index
plus Zero x = x
plus (Up x) y = Up (plus x y)
plus (Down x) y = Down (plus x y)

data Tensor (a :: Index) where
  Scalar :: Double -> Tensor Zero
  Cov :: (Direction -> Tensor b) -> Tensor (Up b)
  Con :: (Direction -> Tensor b) -> Tensor (Down b)

add :: Tensor a -> Tensor a -> Tensor a
add (Scalar x) (Scalar y) = (Scalar (x + y))
add (Cov f) (Cov g) = (Cov (\d -> add (f d) (g d)))
add (Con f) (Con g) = (Con (\d -> add (f d) (g d)))

mul :: Tensor a -> Tensor b -> Tensor (Plus a b)
mul (Scalar x) (Scalar y) = (Scalar (x*y))
mul (Scalar x) (Cov f) = (Cov (\d -> mul (Scalar x) (f d)))
mul (Scalar x) (Con f) = (Con (\d -> mul (Scalar x) (f d)))
mul (Cov f) y = (Cov (\d -> mul (f d) y))
mul (Con f) y = (Con (\d -> mul (f d) y))

There was no ambiguity in Plus but I could have use the disambiguating tick ' to signal that I was dealing with the type level Zero, Up etc.

type family Plus (i :: Index) (j :: Index) :: Index where
  Plus 'Zero x = x
  Plus ('Up x) y  = 'Up (Plus x y)
  Plus ('Down x) y = 'Down (Plus x y)

TypeOperators would permit you to write a + b rather than Plus a b above.

type family (i :: Index) + (j :: Index) :: Index where
  Zero + x = x
  Up x + y  = Up (x + y)
  Down x + y = Down (x + y)