Julia的ForwardDiff可以处理闭包吗?如果不是,那么它就不是很有用了,但是如果是,那么我在下面哪里出错了?
using ForwardDiff
function make_add(x)
foo = y::Vector -> y+x
return foo
end
zulu = make_add(17)
g = x-> ForwardDiff.gradient(zulu, x)
g([1, 2, 3])
MethodError: no method matching extract_gradient!
(::Type{ForwardDiff.Tag{##1#2{Int64},Int64}},
`::Array{Array{ForwardDiff.Dual{ForwardDiff.Tag{##1#2{Int64},Int64},Int64,3},1},1}, ::Array{ForwardDiff.Dual{ForwardDiff.Tag{##1#2{Int64},Int64},Int64,3},1})
Closest candidates are:
extract_gradient!(::Type{T}, ::AbstractArray, ::ForwardDiff.Dual) where T at /home/jrun/.julia/v0.6/ForwardDiff/src/gradient.jl:76
extract_gradient!(::Type{T}, ::AbstractArray, ::Real) where T at /home/jrun/.julia/v0.6/ForwardDiff/src/gradient.jl:75
extract_gradient!(::Type{T}, ::DiffResults.DiffResult, ::ForwardDiff.Dual) where T at /home/jrun/.julia/v0.6/ForwardDiff/src/gradient.jl:70
...
Stacktrace:
[1] gradient(::Function, ::Array{Int64,1}, ::ForwardDiff.GradientConfig{ForwardDiff.Tag{##1#2{Int64},Int64},Int64,3,Array{ForwardDiff.Dual{ForwardDiff.Tag{##1#2{Int64},Int64},Int64,3},1}}, ::Val{true}) at /home/jrun/.julia/v0.6/ForwardDiff/src/gradient.jl:17
[2] gradient(::Function, ::Array{Int64,1}, ::ForwardDiff.GradientConfig{ForwardDiff.Tag{##1#2{Int64},Int64},Int64,3,Array{ForwardDiff.Dual{ForwardDiff.Tag{##1#2{Int64},Int64},Int64,3},1}}) at /home/jrun/.julia/v0.6/ForwardDiff/src/gradient.jl:15
[3] (::##3#4)(::Array{Int64,1}) at ./In[8]:1`
编辑实际上,这与闭包无关。简单地:
h = x-> ForwardDiff.gradient(x-> x+17.0, x)
炸弹完全一样
答案 0 :(得分:1)
gradient
是为数组定义的。在标量上使用derivative
。
答案 1 :(得分:1)
ForwardDiff.gadient
的文档中指出:
此方法假定
isa(f(x), Real)
。
问题是您的函数返回的是向量而不是标量,因此您需要使用jacobian
(接受数组作为返回值):
julia> function make_add(x)
foo = y::Vector -> y .+ x
return foo
end
make_add (generic function with 1 method)
julia> zulu = make_add(17)
#27 (generic function with 1 method)
julia> g = x-> ForwardDiff.jacobian(zulu, x)
#29 (generic function with 1 method)
julia> g([1, 2, 3])
3×3 Array{Int64,2}:
1 0 0
0 1 0
0 0 1
还请注意,我在+
之前添加了一个点(因此其读为y .+ x
),因为在当前版本的Julia 1.0中,不允许在不广播的情况下向向量添加标量。 / p>