我将实现一个使用递归相当多的程序。所以,在我开始获得堆栈溢出异常之前,我认为实现一个蹦床并在需要时使用thunks会很好。
我做的第一次尝试是使用阶乘法。代码如下:
callable(f) = !isempty(methods(f))
function trampoline(f, arg1, arg2)
v = f(arg1, arg2)
while callable(v)
v = v()
end
return v
end
function factorial(n, continuation)
if n == 1
continuation(1)
else
(() -> factorial(n-1, (z -> (() -> continuation(n*z)))))
end
end
function cont(x)
x
end
另外,我实现了一个天真的阶乘,以检查,实际上,我是否会阻止堆栈溢出:
function factorial_overflow(n)
if n == 1
1
else
n*factorial_overflow(n-1)
end
end
结果是:
julia> factorial_overflow(140000)
ERROR: StackOverflowError:
#JITing with a small input
julia> trampoline(factorial, 10, cont)
3628800
#Testing
julia> trampoline(factorial, 140000, cont)
0
所以,是的,我正在避免使用StacksOverflows。是的,我知道结果是无意义的,因为我得到整数溢出,但在这里我只关心堆栈。当然,生产版本可以修复。
(另外,我知道因为有一个内置的因子,我不会使用其中任何一个,我用它来测试我的蹦床)。
蹦床版本在第一次运行时需要花费大量时间,然后在计算相同或更低值时快速...
如果我trampoline(factorial, 150000, cont)
我会再次编译一段时间。
在我看来(有教养的猜测)我正在为因子进行JITing许多不同的签名:每个thunk生成一个。
我的问题是:我可以避免这个吗?
答案 0 :(得分:1)
我认为问题在于每个闭包都是它自己的类型,它专门用于捕获的变量。为了避免这种专业化,可以使用非完全专业化的仿函数:
struct L1
f
n::Int
z::Int
end
(o::L1)() = o.f(o.n*o.z)
struct L2
f
n::Int
end
(o::L2)(z) = L1(o.f, o.n, z)
struct Factorial
f
c
n::Int
end
(o::Factorial)() = o.f(o.n-1, L2(o.c, o.n))
callable(f) = false
callable(f::Union{Factorial, L1, L2}) = true
function myfactorial(n, continuation)
if n == 1
continuation(1)
else
Factorial(myfactorial, continuation, n)
end
end
function cont(x)
x
end
function trampoline(f, arg1, arg2)
v = f(arg1, arg2)
while callable(v)
v = v()
end
return v
end
请注意,函数字段是无类型的。现在,该功能在第一次运行时运行得更快:
julia> @time trampoline(myfactorial, 10, cont)
0.020673 seconds (4.24 k allocations: 264.427 KiB)
3628800
julia> @time trampoline(myfactorial, 10, cont)
0.000009 seconds (37 allocations: 1.094 KiB)
3628800
julia> @time trampoline(myfactorial, 14000, cont)
0.001277 seconds (55.55 k allocations: 1.489 MiB)
0
julia> @time trampoline(myfactorial, 14000, cont)
0.001197 seconds (55.55 k allocations: 1.489 MiB)
0
我刚刚将代码中的每个闭包翻译成相应的仿函数。这可能不是必需的,可能有更好的解决方案,但它有效,并有希望证明这种方法。
修改强>
为了使减速的原因更清楚,可以使用:
function factorial(n, continuation)
if n == 1
continuation(1)
else
tmp = (z -> (() -> continuation(n*z)))
@show typeof(tmp)
(() -> factorial(n-1, tmp))
end
end
输出:
julia> trampoline(factorial, 10, cont)
typeof(tmp) = ##31#34{Int64,#cont}
typeof(tmp) = ##31#34{Int64,##31#34{Int64,#cont}}
typeof(tmp) = ##31#34{Int64,##31#34{Int64,##31#34{Int64,#cont}}}
typeof(tmp) = ##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,#cont}}}}
typeof(tmp) = ##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,#cont}}}}}
typeof(tmp) = ##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,#cont}}}}}}
typeof(tmp) = ##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,#cont}}}}}}}
typeof(tmp) = ##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,#cont}}}}}}}}
typeof(tmp) = ##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,##31#34{Int64,#cont}}}}}}}}}
3628800
tmp
是一个闭包。其自动创建的类型##31#34
看起来类似于
struct Tmp{T,F}
n::T
continuation::F
end
F
字段continuation
类型的特化是编译时间长的原因。
使用L2
而不是专门针对相应字段f
,continuation
factorial
参数始终为L2
类型问题是避免的。