朱莉娅在函数调用上的不安

时间:2019-05-03 14:10:45

标签: julia ambiguity

我有这个错误

ERROR: MethodError: vcat(::Array{Real,2}, ::TrackedArray{…,Array{Float32,2}}) is ambiguous. Candidates:
  vcat(364::AbstractArray, x::Union{TrackedArray, TrackedReal}, xs::Union{Number, AbstractArray}...) in Tracker at C:\Users\Henri\.julia\packages\Tracker\6wcYJ\src\lib\array.jl:167
  vcat(A::Union{AbstractArray{T,2}, AbstractArray{T,1}} where T...) in Base at abstractarray.jl:1296
Possible fix, define
  vcat(::Union{AbstractArray{T,2}, AbstractArray{T,1}} where T, ::Union{TrackedArray{T,1,A} where A<:AbstractArray{T,1} where T, TrackedArray{T,2,A} where A<:AbstractArray{T,2} where T}, ::Vararg{Union{AbstractArray{T,2}, AbstractArray{T,1}} where T,N} where N)

告诉我两个vcat()函数是不明确的。我想使用Base.vcat()函数,但是显式使用它会引发相同的错误。这是为什么 ?错误抛出所提出的“可能的解决方法”是什么?

此外,当我手动调用REPL中的每一行时,不会引发任何错误。我不了解这种行为。仅当vcat()在另一个函数内部调用的函数中时,才会发生这种情况。就像下面的示例一样。

以下是重现该错误的代码:

using Flux

function loss(a, b, net, net2)
    net2(vcat(net(a),a))

end

function test()    
    opt = ADAM()
    net = Chain(Dense(3,3))
    net2 = Chain(Dense(6,1))
    L(a, b) = loss(a, b, net, net2)

    data = tuple(rand(3,1), rand(3,1))
    xs = Flux.params(net)
    gs = Tracker.gradient(() -> L(data...), xs)
    Tracker.update!(opt, xs, gs)
end

1 个答案:

答案 0 :(得分:0)

正如Henri.D在评论中所提到的,我们设法通过谨慎处理a类型,即Array的{​​{1}}类型,默认类型为Float64,而rand返回了net(a)的{​​{1}},并且无法通过TrackedArray Float32来使用。

  

我设法通过更改vcat的损失函数来修复a,因为vcat不能连接为net2(vcat(net(a),Float32.(a)))vcatnet(a)一个Float32 Array。那么a是1个元素中的Float64,而我认为您需要一个L(data...),这就是为什么我最终用TrackedArray取代Float32