{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE OverloadedStrings #-}
module Main where
import qualified Data.ByteString.Char8 as B
import Data.ByteString (ByteString)
infixl 1 |>
(|>) = flip ($)
data CudaVarAr2d x where VarAr2d :: CudaVarScalar Int -> CudaVarScalar Int -> ByteString -> CudaVarAr2d x
data CudaVarAr1d x where VarAr1d :: CudaVarScalar Int -> ByteString -> CudaVarAr1d x
data CudaVarScalar x where VarScalar :: ByteString -> CudaVarScalar x
data CudaVariable x where
VarAr2d' :: CudaVarAr2d x -> CudaVariable x
VarAr1d' :: CudaVarAr1d x -> CudaVariable x
VarScalar' :: CudaVarScalar x -> CudaVariable x
VarTuple2 :: CudaVariable x -> CudaVariable y -> CudaVariable (x,y)
VarTuple3 :: CudaVariable x -> CudaVariable y -> CudaVariable z -> CudaVariable (x,y,z)
size = VarScalar "size"
x1 = VarAr1d' $ VarAr1d size "x1"
x2 = VarAr1d' $ VarAr1d size "x2"
inp = VarTuple2 x1 x2
o1 = VarAr1d' $ VarAr1d size "o1"
o2 = VarAr1d' $ VarAr1d size "o2"
outp = VarTuple2 o1 o2
-- Later I intend to cover all the cases.
varar1d_into_prim_adj :: CudaVariable x -> CudaVariable (x,x)
varar1d_into_prim_adj (VarAr1d' (VarAr1d size name)) = VarTuple2 x1 x2 where
f suffix = VarAr1d' (VarAr1d size ([name,suffix] |> B.concat))
x1 = f "_primal"
x2 = f "_adjoint"
--map_into_prim_adj :: CudaVariable x -> CudaVariable x
map_into_prim_adj x =
let f = varar1d_into_prim_adj in
case x of
VarTuple2 a b -> VarTuple2 (f a) (f b)
VarTuple3 a b c -> VarTuple3 (f a) (f b) (f c)
main :: IO ()
main = print "Hello"
当我为Cuda编译器创建一个接口时,这个带有元组的东西在F#-land中杀了我,所以我在Haskell-land中寻找更多的肥沃土壤。我不知道如何表达map_into_prim_adj
的类型,我希望编译器会为我做,但事实并非如此。
这样做的最终目标是写出像
这样的东西cuda_map_fb
(\(x,y) -> x * y))
(\((x_primal,x_adjoint),(y_primal,y_adjoint)) error -> do
set x_adjoint (x_adjoint + error * y_primal)
set y_adjoint (y_adjoint + error * x_primal))
让它被类型检查并发送给编译器,但首先我需要弄清楚如何将元组先映射为原始和伴随。我要做的是用于自动差异化/深度学习库的微型嵌入式Cuda编译器。
答案 0 :(得分:2)
我不确定您正在做什么,但以下是您使用TypeFamilies
语言扩展程序键入map_into_prim_adj
的方法:
type family ResType t :: * where
ResType (x, y) = ((x, x), (y, y))
ResType (x, y, z) = ((x, x), (y, y), (z, z))
-- ...
map_into_prim_adj :: CudaVariable x -> CudaVariable (ResType x)
map_into_prim_adj x =
let f = varar1d_into_prim_adj in
case x of
VarTuple2 a b -> VarTuple2 (f a) (f b)
VarTuple3 a b c -> VarTuple3 (f a) (f b) (f c)
-- ...
一点解释。类型族有点像从类型到类型的函数。在map_into_prim_adj
x
VarTuple2
x
CudaVariable (a, b)
ResType
的类型为(a, b)
,ResType
的参数为((a, a), (b, b))
{{1}} }这将匹配{{1}}的第一个等式,为我们提供输出类型{{1}}。