在Julia

时间:2019-05-06 18:15:04

标签: julia automatic-differentiation

我正在尝试在一个库中使用ForwardDiff,在该库中,几乎所有功能都只能在Float中使用。我想对这些函数签名进行概括,以便可以在仍然具有足够的限制性的同时使用ForwardDiff,以便函数仅采用数字值,而不采用日期之类的东西。我有很多名称相同但类型不同的函数(即,以相同的函数名称将“时间”作为浮点数或日期作为函数的函数),并且不想在整个过程中都删除类型限定符。

最小工作示例

using ForwardDiff
x = [1.0, 2.0, 3.0, 4.0 ,5.0]
typeof(x) # Array{Float64,1}
function G(x::Array{Real,1})
    return sum(exp.(x))
end
function grad_F(x::Array)
  return ForwardDiff.gradient(G, x)
end
G(x) # Method Error
grad_F(x) # Method error

function G(x::Array{Float64,1})
    return sum(exp.(x))
end
G(x) # This works
grad_F(x) # This has a method error

function G(x)
    return sum(exp.(x))
end
G(x) # This works
grad_F(x) # This works
# But now I cannot restrict the function G to only take numeric arrays and not for instance arrays of Dates.

是否有一种方法可以限制函数只接受数字值(Ints和Floats)以及ForwardDiff使用的任何双数结构,但不允许使用符号,日期等。

2 个答案:

答案 0 :(得分:2)

ForwardDiff.Dual是抽象类型Real的子类型。但是,您遇到的问题是Julia的类型参数是不变的,而不是协变的。然后,以下内容将返回false。

# check if `Array{Float64, 1}` is a subtype of `Array{Real, 1}`
julia> Array{Float64, 1} <: Array{Real, 1}
false

这使您可以定义函数

function G(x::Array{Real,1})
    return sum(exp.(x))
end

不正确(不适合您使用)。这就是为什么您会收到以下错误。

julia> G(x)
ERROR: MethodError: no method matching G(::Array{Float64,1})

正确的定义应该是

function G(x::Array{<:Real,1})
    return sum(exp.(x))
end

或者如果您以某种方式需要轻松访问数组的具体元素类型

 function G(x::Array{T,1}) where {T<:Real}
     return sum(exp.(x))
 end

您的grad_F函数也是如此。

您可能会发现阅读the relevant section有关类型的Julia文档很有用。


您可能还想为AbstractArray{<:Real,1}类型而不是Array{<:Real, 1}类型的函数添加注释,以便您的函数可以处理其他类型的数组,例如StaticArraysOffsetArrays等,而无需重新定义。

答案 1 :(得分:1)

这将接受由任何类型的数字参数化的任何类型的数组:

function foo(xs::AbstractArray{<:Number})
  @show typeof(xs)
end

或:

function foo(xs::AbstractArray{T}) where T<:Number
  @show typeof(xs)
end

如果需要在主体函数中引用类型参数T

x1 = [1.0, 2.0, 3.0, 4.0 ,5.0]
x2 = [1, 2, 3,4, 5]
x3 = 1:5
x4 = 1.0:5.0
x5 = [1//2, 1//4, 1//8]

xss = [x1, x2, x3, x4, x5]

function foo(xs::AbstractArray{T}) where T<:Number
  @show xs typeof(xs) T
  println()
end

for xs in xss
  foo(xs)
end

输出:

xs = [1.0, 2.0, 3.0, 4.0, 5.0]
typeof(xs) = Array{Float64,1}
T = Float64

xs = [1, 2, 3, 4, 5]
typeof(xs) = Array{Int64,1}
T = Int64

xs = 1:5
typeof(xs) = UnitRange{Int64}
T = Int64

xs = 1.0:1.0:5.0
typeof(xs) = StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}
T = Float64

xs = Rational{Int64}[1//2, 1//4, 1//8]
typeof(xs) = Array{Rational{Int64},1}
T = Rational{Int64}

您可以在此处运行示例代码:https://repl.it/@SalchiPapa/Restricting-function-signatures-in-Julia