我一直在Julia编写随机PDE模拟,随着我的问题变得越来越复杂,独立参数的数量也在增加。首先,
myfun(N,M,dt,dx,a,b)
最终成为
myfun(N,M,dt,dx,a,b,c,d,e,f,g,h)
导致(1)代码混乱,(2)由于错误的函数参数导致错误的机会增加,(3)无法推广用于其他函数。
(3)很重要,因为我已经对我的代码进行了简单的并行化,以评估PDE的许多不同运行。所以我想将我的函数转换为一个表单:
myfun(args)
其中args包含所有相关参数。我在Julia中发现的问题是,创建一个包含所有相关参数作为属性的struct
可以大大降低速度。我认为这是由于不断访问struct属性。作为一个简单的(ODE)工作示例,
function example_fun(N,dt,a,b)
V = zeros(N+1)
U = 0
z = randn(N+1)
for i=2:N+1
V[i] = V[i-1]*(1-dt)+U*dt
U = U*(1-dt/a)+b*sqrt(2*dt/a)*z[i]
end
return V
end
如果我尝试将其重写为,
function example_fun2(args)
V = zeros(args.N+1)
U = 0
z = randn(args.N+1)
for i=2:args.N+1
V[i] = V[i-1]*(1-args.dt)+U*args.dt
U = U*(1-args.dt/args.a)+args.b*sqrt(2*args.dt/args.a)*z[i]
end
return V
end
然后,当函数调用看起来很优雅时,重新访问类中的每个属性是很麻烦的,并且这种连续访问属性会减慢模拟速度。什么是更好的解决方案?有没有办法简单地解开'结构的属性,所以不必连续访问它们?如果是这样,这将如何推广?
编辑: 我正在定义我使用的结构如下:
struct Args
N::Int64
dt::Float64
a::Float64
b::Float64
end
edit2:我已经意识到如果你没有在struct定义中指定数组的维度,那么带有Array {}属性的结构会产生性能差异。例如,如果c是一维参数数组,
struct Args_1
N::Int64
c::Array{Float64}
end
将使f(args)的性能远远低于f(N,c)。但是,如果我们指定c是结构定义中的一维数组,
struct Args_1
N::Int64
c::Array{Float64,1}
end
然后性能损失消失。我的函数定义中显示的这个问题和类型不稳定似乎是我在使用struct作为函数参数时遇到的性能差异。
答案 0 :(得分:1)
也许你没有声明args的类型声明参数的类型?
考虑这个小例子:
struct argstype
N
dt
end
myfun(args) = args.N * args.dt
myfun
不是类型稳定的,无法推断返回类型的类型:
@code_warntype myfun(argstype(10,0.1))
Variables:
#self# <optimized out>
args::argstype
Body:
begin
return ((Core.getfield)(args::argstype, :N)::Any * (Core.getfield)(args::argstype, :dt)::Any)::Any
end::Any
但是,如果声明类型,则代码将变为类型稳定:
struct argstype2
N::Int
dt::Float64
end
@code_warntype myfun(argstype2(10,0.1))
Variables:
#self# <optimized out>
args::argstype2
Body:
begin
return (Base.mul_float)((Base.sitofp)(Float64, (Core.getfield)(args::argstype2, :N)::Int64)::Float64, (Core.getfield)(args::argstype2, :dt)::Float64)::Float64
end::Float64
您会看到Float64的推断返回类型。 使用参数类型(https://docs.julialang.org/en/v0.6.3/manual/types/#Parametric-Types-1),您的代码仍然保持通用和类型稳定:
struct argstype3{T1,T2}
N::T1
dt::T2
end
@code_warntype myfun(argstype3(10,0.1))
Variables:
#self# <optimized out>
args::argstype3{Int64,Float64}
Body:
begin
return (Base.mul_float)((Base.sitofp)(Float64, (Core.getfield)(args::argstype3{Int64,Float64}, :N)::Int64)::Float64, (Core.getfield)(args::argstype3{Int64,Float64}, :dt)::Float64)::Float64
end::Float64
答案 1 :(得分:1)
在你的代码中有一个类型不稳定,与U相关,它被初始化为0(整数),但是如果用0替换它(浮点数),类型不稳定性就会消失。
对于原始版本(&#34; U = 0&#34;),函数example_fun需要801.933 ns(对于参数10,0.1,2。,3。)和example_fun2 925.323 ns(对于类似的值)。
在类型稳定版本(U = 0。)中,两者都需要273 ns(+ / 5 ns)。因此,这是一个实质性的加速,并且不再有在类型args中组合参数的惩罚。
这是完整的功能:
function example_fun2(args)
V = zeros(args.N+1)
U = 0.
z = randn(args.N+1)
for i=2:args.N+1
V[i] = V[i-1]*(1-args.dt)+U*args.dt
U = U*(1-args.dt/args.a)+args.b*sqrt(2*args.dt/args.a)*z[i]
end
return V
end