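I'm having trouble using the Haskell ad library to differentiate through a custom data type. There is a related question here, which helped, but I think its approach may not be efficient in this case.
Here is a simplified version of the problem I'm facing: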
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
import Numeric.AD
import Data.Foldable
import Data.Traversable
data Sample a = Sample a deriving (Show, Traversable, Foldable, Functor)
data Model s a = Model {
    sample_logp :: s -> a
  }
instance Functor (Model s) where
  fmap f m = Model {
      sample_logp = f . (sample_logp m)
    }
dtest :: (Num a) => Model (Sample a) a -> Sample a -> Sample a
dtest m x = grad (\x' -> test (fmap auto m) x') x
test :: Num a => Model (Sample a) a -> Sample a -> a
test m x = (sample_logp m) x
This results in the following error message:
Test.hs:22:42:
    Could not deduce (a ~ Numeric.AD.Internal.Reverse.Reverse s a)
    from the context (Num a)
      bound by the type signature for
                 dtest :: Num a => Model (Sample a) a -> Sample a -> Sample a
      at Test.hs:21:10-62
    or from (Data.Reflection.Reifies
               s Numeric.AD.Internal.Reverse.Tape)
      bound by a type expected by the context:
                 Data.Reflection.Reifies s Numeric.AD.Internal.Reverse.Tape =>
                 Sample (Numeric.AD.Internal.Reverse.Reverse s a)
                 -> Numeric.AD.Internal.Reverse.Reverse s a
      at Test.hs:22:13-49
      ‘a’ is a rigid type variable bound by
          the type signature for
            dtest :: Num a => Model (Sample a) a -> Sample a -> Sample a
          at Test.hs:21:10
    Expected type: Model
                     (Sample (Numeric.AD.Internal.Reverse.Reverse s a)) a
      Actual type: Model (Sample a) a
    Relevant bindings include
      x' :: Sample (Numeric.AD.Internal.Reverse.Reverse s a)
        (bound at Test.hs:22:20)
      x :: Sample a (bound at Test.hs:22:9)
      m :: Model (Sample a) a (bound at Test.hs:22:7)
      dtest :: Model (Sample a) a -> Sample a -> Sample a
        (bound at Test.hs:22:1)
    In the second argument of ‘fmap’, namely ‘m’
    In the first argument of ‘test’, namely ‘(fmap auto m)’
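The error says that `fmap auto m` only changes the result type of the model, while its input is still `Sample a`; `grad`, however, passes in `Sample (Reverse s a)`. Because `Model s` is a `Functor` only in its output, `fmap` cannot convert the input side, and the stored function is fixed to the caller's `a`. As a rough check that the custom type itself is not the problem, here is a minimal sketch where `grad` works on the derived `Traversable` instance of `Sample` once the function is polymorphic in its number type (`square` and `squareGrad` are hypothetical names, not from the question):

```haskell
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
import Numeric.AD (grad)

-- Same single-field container as in the question.
data Sample a = Sample a
  deriving (Show, Functor, Foldable, Traversable)

-- Hypothetical objective, polymorphic in the number type, so grad can
-- instantiate it at its internal reverse-mode type.
square :: Num a => Sample a -> a
square (Sample v) = v * v

-- Gradient of square at a point; d(v^2)/dv = 2v.
squareGrad :: Sample Double -> Sample Double
squareGrad = grad square
```

Evaluating `squareGrad (Sample 3)` should give `Sample 6.0`.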
EDIT: Got this working with the suggestion I received:
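-- needs RankNTypes, ScopedTypeVariables and TypeFamilies (for ~) in addition to the pragmas above
dtest :: forall a . (Num a) => (forall b . (Num b) => Model (Sample b) b) -> Sample a -> Sample a
dtest m x = grad go x
  where
    -- go is polymorphic in the AD number type t, so grad can pick t = Reverse s a
    go :: forall t. (Scalar t ~ a, Mode t) => Sample t -> t
    go = test m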
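A self-contained sketch of how the working version might be called. It simplifies the `where`/`go` helper to `grad (test m)`, on the assumption that the rank-2 model argument alone is enough for `grad` to instantiate the model at its reverse-mode type; `squareModel` is a hypothetical toy model, not a real log-density:

```haskell
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE RankNTypes #-}
import Numeric.AD (grad)

data Sample a = Sample a
  deriving (Show, Functor, Foldable, Traversable)

data Model s a = Model { sample_logp :: s -> a }

test :: Num a => Model (Sample a) a -> Sample a -> a
test m x = sample_logp m x

-- The model argument is rank-2: the caller must supply a model that works
-- for every Num type, so grad can instantiate it at its own AD number type.
dtest :: Num a => (forall b. Num b => Model (Sample b) b) -> Sample a -> Sample a
dtest m = grad (test m)

-- Hypothetical toy model: logp(v) = v^2, chosen only because its
-- derivative (2v) is easy to check by hand.
squareModel :: Num b => Model (Sample b) b
squareModel = Model (\(Sample v) -> v * v)
```

With this sketch, `dtest squareModel (Sample 3 :: Sample Double)` should evaluate to `Sample 6.0`. The trade-off is that every model passed to `dtest` must be written polymorphically; a model built from a concrete `Double`-only function will not typecheck.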