我正在编写一个函数来计算重心插值公式的权重。忽略类型稳定性,这很容易:
function baryweights(x)
n = length(x)
if n == 1; return [1.0]; end # This is obviously not type stable
xmin,xmax = extrema(x)
x *= 4/(xmax-xmin)
# ^ Multiply by capacity of interval to avoid overflow
return [
1/prod(x[i]-x[j] for j in 1:n if j != i)
for i = 1:n
]
end
类型稳定性的问题是计算出n > 1
大小写的返回类型,这样我就可以在n == 1
的情况下返回正确类型的数组。有没有一个简单的技巧来实现这一目标?
答案 0 :(得分:2)
我不确定我是否理解你的计划。但也许这样的事情会有所帮助吗? - >
baryone(t::T) where T<:Real = [1.]
baryone(t::T) where T<:Complex = [1im] # or whatever you like here
function baryweights(x::Array{T,1}) where T<:Number
n = length(x)
n == 1 && return baryone(x[1])
xmin,xmax = extrema(x) # don't forget fix extrema for complex! :)
x *= 4/(xmax-xmin)
# ^ Multiply by capacity of interval to avoid overflow
return [
1/prod(x[i]-x[j] for j in 1:n if j != i)
for i = 1:n
]
end
警告:我还是新手!如果我尝试@code_warntype baryweights([1])
,我会看到很多警告。 (如果我不打电话baryone
)。例如n
是Any
!!
编辑:
我asked on discourse现在看到如果我们使用另一个变量(y),@code_warn
会返回更好的结果:
function baryweights(x::Array{T,1}) where T<:Number
n = length(x)
n == 1 && return baryone(x[1])
xmin,xmax = extrema(x) # don't forget fix extrema for complex! :)
let y = x * 4/(xmax-xmin)
# ^ Multiply by capacity of interval to avoid overflow
return [
1/prod(y[i]-y[j] for j in 1:n if j != i)
for i = 1:n
]
end
end
Edit2:我添加了let
以避免y
被Core.Box
编辑
答案 1 :(得分:1)
只需在伪参数上递归调用该函数:
function baryweights(x)
n = length(x)
if n == 1
T = eltype(baryweights(zeros(eltype(x),2)))
return [one(T)]
end
xmin,xmax = extrema(x)
let x = 4/(xmax-xmin) * x
# ^ Multiply by capacity of interval to avoid overflow,
# and wrap in let to avoid another source of type instability
# (https://github.com/JuliaLang/julia/issues/15276)
return [
1/prod(x[i]-x[j] for j in 1:n if j != i)
for i = 1:n
]
end
end