我有以下代码:
circ(x) = x./sqrt(sum(x .* x))
x -> cat(circ(x), circ(x); dims = 1)
但是我希望能够创建一个函数,在该函数中输入数字并将其与circ(x)数连接。
例如:
function Ncircs(n)
#some way to make cat() have as its parameter circ n number of times
end
我可以打电话给Ncircs(2)
并得到
x -> cat(circ(x), circ(x); dims = 1)
或Ncircs(3)
并获得
x -> cat(circ(x), circ(x), circ(x); dims = 1)
或Ncircs(4)
并获得
x -> cat(circ(x), circ(x), circ(x), circ(x); dims = 1)
等
有没有办法做到这一点?我必须使用宏吗?
答案 0 :(得分:3)
您可以写:
Ncircs(n) = x -> cat(Iterators.repeated(circ(x), n)...; dims = 1)
,如果您知道自己总是会做dims=1
,请使用cat
和vcat
重新电镀reduce
Ncircs(n) = x -> reduce(vcat, Iterators.repeated(circ(x), n))
对于大型n
,效率更高。
请注意:使用其他选项(vcat
)会产生类型稳定的结果,而第一个选项不是类型稳定的。
编辑
通常,原因是您无法确定减少的结果。如果要允许空集合,则应添加init
关键字参数。这是一个示例:
julia> reduce(vcat, [])
ERROR: ArgumentError: reducing over an empty collection is not allowed
julia> reduce(vcat, [], init = [1])
1-element Array{Int64,1}:
1
julia> reduce(vcat, [[2,3], [4,5]], init = [1])
5-element Array{Int64,1}:
1
2
3
4
5
这意味着Julia能够在编译时(在执行代码之前)告诉函数返回值的类型是什么。类型稳定的代码通常运行速度更快(尽管这是一个广泛的话题-我建议您阅读Julia手册以详细了解它)。您可以使用@code_warntype
和Test.@inferred
检查该函数的类型是否稳定。
在这里,请允许我针对您的具体情况进行说明(我将部分输出内容删节以缩短答案)。
julia> x = [1,2,3]
3-element Array{Int64,1}:
1
2
3
julia> y = [4,5,6]
3-element Array{Int64,1}:
4
5
6
julia> @code_warntype vcat(x,y)
Body::Array{Int64,1}
...
julia> @code_warntype cat(x,y, dims=1)
Body::Any
...
julia> using Test
julia> @inferred vcat(x,y)
6-element Array{Int64,1}:
1
2
3
4
5
6
julia> @inferred cat(x,y, dims=1)
ERROR: return type Array{Int64,1} does not match inferred return type Any
上面的 Any
意味着编译器不知道答案的类型。在这种情况下,原因是此类型取决于dims
参数。如果是1
,它将是一个向量;如果是2
,它将是一个矩阵。
n
您可以运行@which
宏:
julia> @which reduce(vcat, [[1,2,3], [4,5,6]])
reduce(::typeof(vcat), A::AbstractArray{#s72,1} where #s72<:(Union{AbstractArray{T,2}, AbstractArray{T,1}} where T)) in Base at abstractarray.jl:1321
您会发现reduce
有专门的vcat
方法。
现在,如果您运行:
@edit reduce(vcat, [[1,2,3], [4,5,6]])
将打开一个编辑器,您会看到它调用了一个内部函数_typed_vcat
,该函数针对vcat
优化了很多数组。引入此优化的原因是,使用像这样的vcat([[1,2,3], [4,5,6]]...)
这样的splatting在结果上是等效的,但是您必须进行splatting(...
),它本身具有一些成本,使用{{1 }}版本。
为了确保我所说的是正确的,您可以执行以下基准测试:
reduce
您会发现julia> using BenchmarkTools
julia> y = [[i] for i in 1:10000];
julia> @benchmark vcat($y...)
BenchmarkTools.Trial:
memory estimate: 156.45 KiB
allocs estimate: 3
--------------
minimum time: 67.200 μs (0.00% GC)
median time: 77.800 μs (0.00% GC)
mean time: 102.804 μs (8.50% GC)
maximum time: 35.179 ms (99.47% GC)
--------------
samples: 10000
evals/sample: 1
julia> @benchmark reduce(vcat, $y)
BenchmarkTools.Trial:
memory estimate: 78.20 KiB
allocs estimate: 2
--------------
minimum time: 67.700 μs (0.00% GC)
median time: 69.700 μs (0.00% GC)
mean time: 82.442 μs (6.39% GC)
maximum time: 32.719 ms (99.58% GC)
--------------
samples: 10000
evals/sample: 1
julia> @benchmark cat($y..., dims=1)
ERROR: StackOverflowError:
版本比reduce
的splatting版本要快一些,而vcat
仅在很大的cat
上失败(对于较小的n
可以,但速度会慢一些。