julia:避免许多函数参数的高效而整洁的方法

时间:2018-06-14 15:28:44

标签: struct julia numerical-methods

我一直在Julia编写随机PDE模拟,随着我的问题变得越来越复杂,独立参数的数量也在增加。首先,

myfun(N,M,dt,dx,a,b)

最终成为

myfun(N,M,dt,dx,a,b,c,d,e,f,g,h)

导致(1)代码混乱,(2)由于错误的函数参数导致错误的机会增加,(3)无法推广用于其他函数。

(3)很重要,因为我已经对我的代码进行了简单的并行化,以评估PDE的许多不同运行。所以我想将我的函数转换为一个表单:

myfun(args)

其中args包含所有相关参数。我在Julia中发现的问题是,创建一个包含所有相关参数作为属性的struct可以大大降低速度。我认为这是由于不断访问struct属性。作为一个简单的(ODE)工作示例,

function example_fun(N,dt,a,b)
  V = zeros(N+1)
  U = 0
  z = randn(N+1)
  for i=2:N+1
    V[i] = V[i-1]*(1-dt)+U*dt
    U = U*(1-dt/a)+b*sqrt(2*dt/a)*z[i]
  end
  return V
end

如果我尝试将其重写为,

function example_fun2(args)
  V = zeros(args.N+1)
  U = 0
  z = randn(args.N+1)
  for i=2:args.N+1
    V[i] = V[i-1]*(1-args.dt)+U*args.dt
    U = U*(1-args.dt/args.a)+args.b*sqrt(2*args.dt/args.a)*z[i]
  end
  return V
end

然后,当函数调用看起来很优雅时,重新访问类中的每个属性是很麻烦的,并且这种连续访问属性会减慢模拟速度。什么是更好的解决方案?有没有办法简单地解开'结构的属性,所以不必连续访问它们?如果是这样,这将如何推广?

编辑: 我正在定义我使用的结构如下:

struct Args
  N::Int64
  dt::Float64
  a::Float64
  b::Float64
end

edit2:我已经意识到如果你没有在struct定义中指定数组的维度,那么带有Array {}属性的结构会产生性能差异。例如,如果c是一维参数数组,

struct Args_1
  N::Int64
  c::Array{Float64}
end

将使f(args)的性能远远低于f(N,c)。但是,如果我们指定c是结构定义中的一维数组,

struct Args_1
  N::Int64
  c::Array{Float64,1}
end

然后性能损失消失。我的函数定义中显示的这个问题和类型不稳定似乎是我在使用struct作为函数参数时遇到的性能差异。

2 个答案:

答案 0 :(得分:1)

也许你没有声明args的类型声明参数的类型?

考虑这个小例子:

struct argstype
    N
    dt
end
myfun(args) = args.N * args.dt

myfun不是类型稳定的,无法推断返回类型的类型:

@code_warntype myfun(argstype(10,0.1))
Variables:
  #self# <optimized out>
  args::argstype

Body:
  begin 
      return ((Core.getfield)(args::argstype, :N)::Any * (Core.getfield)(args::argstype, :dt)::Any)::Any
  end::Any

但是,如果声明类型,则代码将变为类型稳定:

struct argstype2
    N::Int
    dt::Float64
end

@code_warntype myfun(argstype2(10,0.1))
Variables:
  #self# <optimized out>
  args::argstype2

Body:
  begin 
      return (Base.mul_float)((Base.sitofp)(Float64, (Core.getfield)(args::argstype2, :N)::Int64)::Float64, (Core.getfield)(args::argstype2, :dt)::Float64)::Float64
  end::Float64

您会看到Float64的推断返回类型。 使用参数类型(https://docs.julialang.org/en/v0.6.3/manual/types/#Parametric-Types-1),您的代码仍然保持通用和类型稳定:

struct argstype3{T1,T2}
    N::T1
    dt::T2
end

@code_warntype myfun(argstype3(10,0.1))
Variables:
  #self# <optimized out>
  args::argstype3{Int64,Float64}

Body:
  begin 
      return (Base.mul_float)((Base.sitofp)(Float64, (Core.getfield)(args::argstype3{Int64,Float64}, :N)::Int64)::Float64, (Core.getfield)(args::argstype3{Int64,Float64}, :dt)::Float64)::Float64
  end::Float64

答案 1 :(得分:1)

在你的代码中有一个类型不稳定,与U相关,它被初始化为0(整数),但是如果用0替换它(浮点数),类型不稳定性就会消失。

对于原始版本(&#34; U = 0&#34;),函数example_fun需要801.933 ns(对于参数10,0.1,2。,3。)和example_fun2 925.323 ns(对于类似的值)。

在类型稳定版本(U = 0。)中,两者都需要273 ns(+ / 5 ns)。因此,这是一个实质性的加速,并且不再有在类型args中组合参数的惩罚。

这是完整的功能:

function example_fun2(args)
     V = zeros(args.N+1)
     U = 0.
     z = randn(args.N+1)
     for i=2:args.N+1
       V[i] = V[i-1]*(1-args.dt)+U*args.dt
       U = U*(1-args.dt/args.a)+args.b*sqrt(2*args.dt/args.a)*z[i]
     end
     return V
end