I am trying to compile the following minimal example using Numeric.AD:
import Numeric.AD
timeAndGrad f l = grad f l
main = putStrLn "hi"
I get this error:
test.hs:3:24:
Couldn't match expected type ‘f (Numeric.AD.Internal.Reverse.Reverse
s a)
-> Numeric.AD.Internal.Reverse.Reverse s a’
with actual type ‘t’
because type variable ‘s’ would escape its scope
This (rigid, skolem) type variable is bound by
a type expected by the context:
Data.Reflection.Reifies s Numeric.AD.Internal.Reverse.Tape =>
f (Numeric.AD.Internal.Reverse.Reverse s a)
-> Numeric.AD.Internal.Reverse.Reverse s a
at test.hs:3:19-26
Relevant bindings include
l :: f a (bound at test.hs:3:15)
f :: t (bound at test.hs:3:13)
timeAndGrad :: t -> f a -> f a (bound at test.hs:3:1)
In the first argument of ‘grad’, namely ‘f’
In the expression: grad f l
Any clue as to why this happens? From looking at previous examples, I gather that this is "flattening" the type of grad:
grad :: (Traversable f, Num a) => (forall s. Reifies s Tape => f (Reverse s a) -> Reverse s a) -> f a -> f a
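But I actually need to do something like this in my code. In fact, this is the most minimal example that fails to compile. The more complicated thing I want to do looks something like this: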
example :: SomeType
example f x args = (do stuff with the gradient and gradient "function")
where gradient = grad f x
gradientFn = grad f
(other where clauses involving gradient and gradient "function")
Here is a slightly more complicated version, with type signatures, that does compile.
{-# LANGUAGE RankNTypes #-}
import Numeric.AD
import Numeric.AD.Internal.Reverse
-- compiles but I can't figure out how to use it in code
grad2 :: (Show a, Num a, Floating a) => (forall s.[Reverse s a] -> Reverse s a) -> [a] -> [a]
grad2 f l = grad f l
-- compiles with the right type, but the resulting gradient is all 0s...
grad2' :: (Show a, Num a, Floating a) => ([a] -> a) -> [a] -> [a]
grad2' f l = grad f' l
where f' = Lift . f . extractAll
-- i've tried using the Reverse constructor with Reverse 0 _, Reverse 1 _, and Reverse 2 _, but those don't yield the correct gradient. Not sure how the modes work
extractAll :: [Reverse t a] -> [a]
extractAll xs = map extract xs
where extract (Lift x) = x -- non-exhaustive pattern match
dist :: (Show a, Num a, Floating a) => [a] -> a
dist [x, y] = sqrt(x^2 + y^2)
-- incorrect output: [0.0, 0.0]
main = putStrLn $ show $ grad2' dist [1,2]
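However, I can't figure out how to use the first version, grad2, in code, because I don't know how to deal with Reverse s a. The second version, grad2', has the right type because I use the internal constructor Lift to create a Reverse s a, but I must not understand how the internals (specifically the parameter s) work, because the output gradient is all 0s. Using the other constructor, Reverse (not shown here), also produces the wrong gradient.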
Alternatively, are there any libraries or code examples where people use the Reverse code? I think my use case is a very common one.
Answer 0 (score: 2)
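With where f' = Lift . f . extractAll, you essentially create a back door into the underlying type of the automatic differentiation that throws away all the derivatives and keeps only the constant values. If you then pass this to grad, it is hardly surprising that you get a zero result!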
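To make the "back door" concrete, here is a minimal sketch with a toy forward-mode dual number. The names Dual, lift, diffPoly and diffBackdoor are made up for this illustration and are not part of the ad library; the real Reverse type works differently internally, but the effect of wrapping a plain value as a constant is analogous.

```haskell
-- Toy dual number: a value paired with its derivative.
-- Illustrative stand-in only, NOT the ad library's Reverse type.
data Dual = Dual { primal :: Double, tangent :: Double }
  deriving Show

instance Num Dual where
  Dual a a' + Dual b b' = Dual (a + b) (a' + b')
  Dual a a' - Dual b b' = Dual (a - b) (a' - b')
  Dual a a' * Dual b b' = Dual (a * b) (a' * b + a * b')  -- product rule
  negate (Dual a a')    = Dual (negate a) (negate a')
  abs (Dual a a')       = Dual (abs a) (a' * signum a)
  signum (Dual a _)     = Dual (signum a) 0
  fromInteger n         = Dual (fromInteger n) 0

-- Hypothetical analogue of Lift: embed a constant, so the derivative is 0.
lift :: Double -> Dual
lift x = Dual x 0

-- Differentiate by running the function on dual numbers directly.
diffPoly :: (Dual -> Dual) -> Double -> Double
diffPoly f x = tangent (f (Dual x 1))

-- The back door: strip the derivative, run f on plain Doubles,
-- then re-wrap the result as a constant.
diffBackdoor :: (Double -> Double) -> Double -> Double
diffBackdoor f x = tangent (lift (f (primal (Dual x 1))))

square :: Num a => a -> a
square x = x * x

main :: IO ()
main = do
  print (diffPoly square 3)      -- prints 6.0
  print (diffBackdoor square 3)  -- prints 0.0
```

The second call never lets the function see the derivative part, so the only thing left to read out is the 0 that lift put there.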
The sensible approach is to just use grad directly:
import Numeric.AD
dist :: Floating a => [a] -> a
dist [x, y] = sqrt $ x^2 + y^2
-- preferable is of course `dist = sqrt . sum . map (^2)`
main = print $ grad dist [1,2]
-- output: [0.4472135954999579,0.8944271909999159]
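You really don't need to know anything more sophisticated to use automatic differentiation. As long as you only differentiate Num- or Floating-polymorphic functions, everything will work as-is. If you need to differentiate a function that is passed in as an argument, you need to make that argument rank-2 polymorphic (an alternative would be to switch to the rank-1 versions of the ad functions, but I daresay that is less elegant and doesn't really gain you much).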
{-# LANGUAGE Rank2Types, UnicodeSyntax #-}
mainWith :: (∀n . Floating n => [n] -> n) -> IO ()
mainWith f = print $ grad f [1,2]
main = mainWith dist
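Applied to the question's code, the same idea might look like the sketch below. It assumes the ad package is installed; fixing the element type to Double is a simplification chosen here, and gradientAndFn is a hypothetical stand-in for the question's example function.

```haskell
{-# LANGUAGE RankNTypes #-}
import Numeric.AD (grad)

-- The question's timeAndGrad, now with a rank-2 signature so that
-- grad can instantiate f at its own internal number type.
timeAndGrad :: (forall n. Floating n => [n] -> n) -> [Double] -> [Double]
timeAndGrad f l = grad f l

-- Hypothetical stand-in for the question's `example`: returns both the
-- gradient at x and the gradient "function", built in a where clause.
gradientAndFn
  :: (forall n. Floating n => [n] -> n)
  -> [Double]
  -> ([Double], [Double] -> [Double])
gradientAndFn f x = (gradient, gradientFn)
  where
    gradient   = grad f x
    gradientFn = grad f

dist :: Floating a => [a] -> a
dist = sqrt . sum . map (^ 2)

main :: IO ()
main = do
  print (timeAndGrad dist [1, 2])
  let (g, gradFn) = gradientAndFn dist [1, 2]
  print g
  print (gradFn [3, 4])
```

The key point is that the caller passes a function that is polymorphic in its number type (like dist), not one already specialised to [Double] -> Double; a specialised function cannot be differentiated this way.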