包含区分大小的函数的类型稳定性

时间:2017-11-17 08:40:01

标签: julia generic-programming

我正在编写一个函数来计算重心插值公式的权重。忽略类型稳定性,这很容易:

function baryweights(x)
    n = length(x)
    if n == 1; return [1.0]; end # This is obviously not type stable

    xmin,xmax = extrema(x)
    x *= 4/(xmax-xmin)
    # ^ Multiply by capacity of interval to avoid overflow
    return [
        1/prod(x[i]-x[j] for j in 1:n if j != i)
        for i = 1:n
    ]
end

类型稳定性的问题是计算出n > 1大小写的返回类型,这样我就可以在n == 1的情况下返回正确类型的数组。有没有一个简单的技巧来实现这一目标?

2 个答案:

答案 0 :(得分:2)

我不确定我是否理解你的计划。但也许这样的事情会有所帮助吗? - >

baryone(t::T) where T<:Real = [1.]
baryone(t::T) where T<:Complex = [1im]  # or whatever you like here

function baryweights(x::Array{T,1}) where T<:Number
    n = length(x)
    n == 1 && return baryone(x[1])
    xmin,xmax = extrema(x)  # don't forget fix extrema for complex! :)
    x *= 4/(xmax-xmin)
    # ^ Multiply by capacity of interval to avoid overflow
    return [
        1/prod(x[i]-x[j] for j in 1:n if j != i)
        for i = 1:n
    ]
end

警告:我还是新手!如果我尝试@code_warntype baryweights([1]),我会看到很多警告。 (如果我不打电话baryone)。例如nAny !!

编辑: 我asked on discourse现在看到如果我们使用另一个变量(y),@code_warn会返回更好的结果:

function baryweights(x::Array{T,1}) where T<:Number
    n = length(x)
    n == 1 && return baryone(x[1])
    xmin,xmax = extrema(x)  # don't forget fix extrema for complex! :)
    let y = x * 4/(xmax-xmin)
        # ^ Multiply by capacity of interval to avoid overflow
        return [
            1/prod(y[i]-y[j] for j in 1:n if j != i)
            for i = 1:n
        ]
    end
end

Edit2:我添加了let以避免yCore.Box编辑

答案 1 :(得分:1)

只需在伪参数上递归调用该函数:

function baryweights(x)
    n = length(x)
    if n == 1
        T = eltype(baryweights(zeros(eltype(x),2)))
        return [one(T)]
    end

    xmin,xmax = extrema(x)
    let x = 4/(xmax-xmin) * x
        # ^ Multiply by capacity of interval to avoid overflow,
        #   and wrap in let to avoid another source of type instability
        #   (https://github.com/JuliaLang/julia/issues/15276)
        return [
            1/prod(x[i]-x[j] for j in 1:n if j != i)
            for i = 1:n
        ]
    end
end