有没有可以让我控制var args数量的函数?

时间:2019-06-26 21:18:34

标签: julia

我有以下代码:

circ(x) = x./sqrt(sum(x .* x))

x -> cat(circ(x), circ(x); dims = 1)

但是我希望能够创建一个函数,在该函数中输入数字并将其与circ(x)数连接。

例如:

function Ncircs(n)
  #some way to make cat() have as its parameter circ n number of times
end

我可以打电话给Ncircs(2)并得到 x -> cat(circ(x), circ(x); dims = 1)Ncircs(3)并获得 x -> cat(circ(x), circ(x), circ(x); dims = 1)Ncircs(4)并获得 x -> cat(circ(x), circ(x), circ(x), circ(x); dims = 1)

有没有办法做到这一点?我必须使用宏吗?

1 个答案:

答案 0 :(得分:3)

您可以写:

Ncircs(n) = x -> cat(Iterators.repeated(circ(x), n)...; dims = 1)

,如果您知道自己总是会做dims=1,请使用catvcat重新电镀reduce

Ncircs(n) = x -> reduce(vcat, Iterators.repeated(circ(x), n))

对于大型n,效率更高。

请注意:使用其他选项(vcat)会产生类型稳定的结果,而第一个选项不是类型稳定的。

编辑

为什么不允许减少空集合?

通常,原因是您无法确定减少的结果。如果要允许空集合,则应添加init关键字参数。这是一个示例:

julia> reduce(vcat, [])
ERROR: ArgumentError: reducing over an empty collection is not allowed

julia> reduce(vcat, [], init = [1])
1-element Array{Int64,1}:
 1

julia> reduce(vcat, [[2,3], [4,5]], init = [1])
5-element Array{Int64,1}:
 1
 2
 3
 4
 5

结果是类型稳定的是什么意思

这意味着Julia能够在编译时(在执行代码之前)告诉函数返回值的类型是什么。类型稳定的代码通常运行速度更快(尽管这是一个广泛的话题-我建议您阅读Julia手册以详细了解它)。您可以使用@code_warntypeTest.@inferred检查该函数的类型是否稳定。

在这里,请允许我针对您的具体情况进行说明(我将部分输出内容删节以缩短答案)。

julia> x = [1,2,3]
3-element Array{Int64,1}:
 1
 2
 3

julia> y = [4,5,6]
3-element Array{Int64,1}:
 4
 5
 6

julia> @code_warntype vcat(x,y)
Body::Array{Int64,1}
...

julia> @code_warntype cat(x,y, dims=1)
Body::Any
...

julia> using Test

julia> @inferred vcat(x,y)
6-element Array{Int64,1}:
 1
 2
 3
 4
 5
 6

julia> @inferred cat(x,y, dims=1)
ERROR: return type Array{Int64,1} does not match inferred return type Any
上面的

Any意味着编译器不知道答案的类型。在这种情况下,原因是此类型取决于dims参数。如果是1,它将是一个向量;如果是2,它将是一个矩阵。

我怎么知道对于大型n

您可以运行@which宏:

julia> @which reduce(vcat, [[1,2,3], [4,5,6]])
reduce(::typeof(vcat), A::AbstractArray{#s72,1} where #s72<:(Union{AbstractArray{T,2}, AbstractArray{T,1}} where T)) in Base at abstractarray.jl:1321

您会发现reduce有专门的vcat方法。

现在,如果您运行:

@edit reduce(vcat, [[1,2,3], [4,5,6]])

将打开一个编辑器,您会看到它调用了一个内部函数_typed_vcat,该函数针对vcat优化了很多数组。引入此优化的原因是,使用像这样的vcat([[1,2,3], [4,5,6]]...)这样的splatting在结果上是等效的,但是您必须进行splatting(...),它本身具有一些成本,使用{{1 }}版本。

为了确保我所说的是正确的,您可以执行以下基准测试:

reduce

您会发现julia> using BenchmarkTools julia> y = [[i] for i in 1:10000]; julia> @benchmark vcat($y...) BenchmarkTools.Trial: memory estimate: 156.45 KiB allocs estimate: 3 -------------- minimum time: 67.200 μs (0.00% GC) median time: 77.800 μs (0.00% GC) mean time: 102.804 μs (8.50% GC) maximum time: 35.179 ms (99.47% GC) -------------- samples: 10000 evals/sample: 1 julia> @benchmark reduce(vcat, $y) BenchmarkTools.Trial: memory estimate: 78.20 KiB allocs estimate: 2 -------------- minimum time: 67.700 μs (0.00% GC) median time: 69.700 μs (0.00% GC) mean time: 82.442 μs (6.39% GC) maximum time: 32.719 ms (99.58% GC) -------------- samples: 10000 evals/sample: 1 julia> @benchmark cat($y..., dims=1) ERROR: StackOverflowError: 版本比reduce的splatting版本要快一些,而vcat仅在很大的cat上失败(对于较小的n可以,但速度会慢一些。