Automatic differentiation with custom data types

Time: 2016-03-28 17:24:26

Tags: haskell automatic-differentiation

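I'm having trouble using the Haskell ad library to differentiate functions over a custom data type. There is a related question here, which helped, but I think its approach may not be efficient in this case.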

Here is a simplified version of the problem I'm facing:

{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}

import Numeric.AD
import Data.Foldable
import Data.Traversable

data Sample a = Sample a deriving (Show, Traversable, Foldable, Functor)

data Model s a = Model {
  sample_logp :: s -> a
}

instance Functor (Model s) where
  fmap f m = Model {
    sample_logp = f . (sample_logp m)
  }

dtest :: (Num a) => Model (Sample a) a -> Sample a -> Sample a
dtest m x = grad (\x' -> test (fmap auto m) x') x

test :: Num a => Model (Sample a) a -> Sample a -> a
test m x = (sample_logp m)  x

This produces the following error message:

Test.hs:22:42:
Could not deduce (a ~ Numeric.AD.Internal.Reverse.Reverse s a)
from the context (Num a)
  bound by the type signature for
             dtest :: Num a => Model (Sample a) a -> Sample a -> Sample a
  at Test.hs:21:10-62
or from (Data.Reflection.Reifies
           s Numeric.AD.Internal.Reverse.Tape)
  bound by a type expected by the context:
             Data.Reflection.Reifies s Numeric.AD.Internal.Reverse.Tape =>
             Sample (Numeric.AD.Internal.Reverse.Reverse s a)
             -> Numeric.AD.Internal.Reverse.Reverse s a
  at Test.hs:22:13-49
  ‘a’ is a rigid type variable bound by
      the type signature for
        dtest :: Num a => Model (Sample a) a -> Sample a -> Sample a
      at Test.hs:21:10
Expected type: Model
                 (Sample (Numeric.AD.Internal.Reverse.Reverse s a)) a
  Actual type: Model (Sample a) a
Relevant bindings include
  x' :: Sample (Numeric.AD.Internal.Reverse.Reverse s a)
    (bound at Test.hs:22:20)
  x :: Sample a (bound at Test.hs:22:9)
  m :: Model (Sample a) a (bound at Test.hs:22:7)
  dtest :: Model (Sample a) a -> Sample a -> Sample a
    (bound at Test.hs:22:1)
In the second argument of ‘fmap’, namely ‘m’
In the first argument of ‘test’, namely ‘(fmap auto m)’
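The error comes from the direction in which fmap works: fmap auto m only changes the result type of the model, so its argument type stays Sample a, while grad needs a function whose argument is Sample (Reverse s a). The minimal sketch below has no ad dependency, and lmapModel is a hypothetical helper, not part of any library. It shows that fmap on this Model can only transform the output, and that changing the input type means pre-composing with a function that goes from the new input type back to the old one.

```haskell
-- Redefined here so the block compiles on its own; same shape as the question's Model
newtype Model s a = Model { sample_logp :: s -> a }

-- fmap post-composes, so only the result type a can change
instance Functor (Model s) where
  fmap f (Model g) = Model (f . g)

-- Hypothetical helper: changing the input type s means pre-composing,
-- which needs a function from the NEW input type back to the OLD one
lmapModel :: (s' -> s) -> Model s a -> Model s' a
lmapModel h (Model g) = Model (g . h)

main :: IO ()
main = do
  let m = Model (\x -> x * 2 :: Int)
  -- Output side: Model Int Int becomes Model Int String
  print (sample_logp (fmap show m) 21)            -- prints "42"
  -- Input side: Model Int Int becomes Model String Int via length :: String -> Int
  print (sample_logp (lmapModel length m) "abc")  -- prints 6
```

For the question's model, that function would have type Sample (Reverse s a) -> Sample a, and it would have to drop the derivative information that grad relies on, so the model itself has to be polymorphic in its scalar type.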

Edit: I got this working with leftaroundabout's suggestion:

-- needs the RankNTypes, ScopedTypeVariables and TypeFamilies (for ~) extensions
dtest :: forall a . (Num a) => (forall b . (Num b) => Model (Sample b) b) -> Sample a -> Sample a
dtest m x = grad go x
  where
    go :: forall t. (Scalar t ~ a, Mode t) => Sample t -> t
    go = test m
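Why the rank-2 signature works can be reproduced without the ad package. The sketch below swaps ad's reverse mode for a toy forward-mode dual number; Dual, gradF, quadModel and the two-field Sample are hypothetical names for illustration only, not part of ad. Because the caller supplies a model that is polymorphic in its scalar type, dtest is free to instantiate it at Dual, in the same way grad instantiates the model at Reverse s a in the working version.

```haskell
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}

import Data.Traversable (mapAccumL)

-- Hypothetical two-field sample so the gradient has more than one component
data Sample a = Sample a a deriving (Show, Eq, Functor, Foldable, Traversable)

newtype Model s a = Model { sample_logp :: s -> a }

-- Toy forward-mode dual number (value, derivative); stands in for ad's Reverse
data Dual = Dual Double Double

instance Num Dual where
  Dual a a' + Dual b b' = Dual (a + b) (a' + b')
  Dual a a' - Dual b b' = Dual (a - b) (a' - b')
  Dual a a' * Dual b b' = Dual (a * b) (a * b' + a' * b)
  abs (Dual a a') = Dual (abs a) (a' * signum a)
  signum (Dual a _) = Dual (signum a) 0
  fromInteger n = Dual (fromInteger n) 0

-- Gradient of obj at xs: one forward pass per position,
-- seeding that position with derivative 1 and all others with 0
gradF :: Traversable f => (f Dual -> Dual) -> f Double -> f Double
gradF obj xs = snd (mapAccumL step 0 xs)
  where
    step :: Int -> Double -> (Int, Double)
    step i _ = (i + 1, tangent (obj (seed i)))
    seed i = snd (mapAccumL (\j x -> (j + 1, Dual x (if j == i then 1 else 0))) 0 xs)
    tangent (Dual _ d) = d

-- Hypothetical model: log p(x, y) = -(x^2 + 3xy), polymorphic in the scalar type
quadModel :: Num b => Model (Sample b) b
quadModel = Model (\(Sample x y) -> negate (x * x + 3 * x * y))

-- Same rank-2 shape as the working dtest, specialised to Double
dtest :: (forall b. Num b => Model (Sample b) b) -> Sample Double -> Sample Double
dtest m = gradF (sample_logp m)

main :: IO ()
main = print (dtest quadModel (Sample 1 2))  -- prints Sample (-8.0) (-3.0)
```

Forward mode needs one pass per input, so with many parameters ad's reverse-mode grad is the better fit; the sketch is only meant to show why the model argument has to be rank-2.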

0 answers:

No answers yet