如何使用GADT进行通用元组映射?

时间:2017-01-21 15:58:13

标签: haskell

{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE OverloadedStrings #-}

module Main where

import qualified Data.ByteString.Char8 as B
import Data.ByteString (ByteString)

infixl 1 |>
(|>) = flip ($)

data CudaVarAr2d x where VarAr2d :: CudaVarScalar Int -> CudaVarScalar Int -> ByteString -> CudaVarAr2d x
data CudaVarAr1d x where VarAr1d :: CudaVarScalar Int -> ByteString -> CudaVarAr1d x
data CudaVarScalar x where VarScalar :: ByteString -> CudaVarScalar x

data CudaVariable x where
  VarAr2d' :: CudaVarAr2d x -> CudaVariable x
  VarAr1d' :: CudaVarAr1d x -> CudaVariable x
  VarScalar' :: CudaVarScalar x -> CudaVariable x
  VarTuple2 :: CudaVariable x -> CudaVariable y -> CudaVariable (x,y)
  VarTuple3 :: CudaVariable x -> CudaVariable y -> CudaVariable z -> CudaVariable (x,y,z)

size = VarScalar "size"
x1 = VarAr1d' $ VarAr1d size "x1"
x2 = VarAr1d' $ VarAr1d size "x2"
inp = VarTuple2 x1 x2

o1 = VarAr1d' $ VarAr1d size "o1"
o2 = VarAr1d' $ VarAr1d size "o2"
outp = VarTuple2 o1 o2

-- Later I intend to cover all the cases.
varar1d_into_prim_adj :: CudaVariable x -> CudaVariable (x,x)
varar1d_into_prim_adj (VarAr1d' (VarAr1d size name)) = VarTuple2 x1 x2 where
  f suffix = VarAr1d' (VarAr1d size ([name,suffix] |> B.concat))
  x1 = f "_primal"
  x2 = f "_adjoint"

--map_into_prim_adj :: CudaVariable x -> CudaVariable x
map_into_prim_adj x =
  let f = varar1d_into_prim_adj in
  case x of
    VarTuple2 a b -> VarTuple2 (f a) (f b)
    VarTuple3 a b c -> VarTuple3 (f a) (f b) (f c)

main :: IO ()
main = print "Hello"

当我为Cuda编译器创建一个接口时,这个带有元组的东西在F#-land中杀了我,所以我在Haskell-land中寻找更多的肥沃土壤。我不知道如何表达map_into_prim_adj的类型,我希望编译器会为我做,但事实并非如此。

这样做的最终目标是写出像

这样的东西
cuda_map_fb 
    (\(x,y) -> x * y)) 
    (\((x_primal,x_adjoint),(y_primal,y_adjoint)) error -> do
        set x_adjoint (x_adjoint + error * y_primal)
        set y_adjoint (y_adjoint + error * x_primal))

让它被类型检查并发送给编译器,但首先我需要弄清楚如何将元组先映射为原始和伴随。我要做的是用于自动差异化/深度学习库的微型嵌入式Cuda编译器。

1 个答案:

答案 0 :(得分:2)

我不确定您正在做什么,但以下是您使用TypeFamilies语言扩展程序键入map_into_prim_adj的方法:

type family ResType t :: * where
  ResType (x, y) = ((x, x), (y, y))
  ResType (x, y, z) = ((x, x), (y, y), (z, z))
  -- ...

map_into_prim_adj :: CudaVariable x -> CudaVariable (ResType x)
map_into_prim_adj x =
  let f = varar1d_into_prim_adj in
  case x of
    VarTuple2 a b -> VarTuple2 (f a) (f b)
    VarTuple3 a b c -> VarTuple3 (f a) (f b) (f c)
    -- ...

一点解释。类型族有点像从类型到类型的函数。在map_into_prim_adj x VarTuple2 x CudaVariable (a, b) ResType的类型为(a, b)ResType的参数为((a, a), (b, b)) {{1}} }这将匹配{{1}}的第一个等式,为我们提供输出类型{{1}}。