如何编写一个函数,该函数将检查每个调用方法的返回类型是否可静态推断?

时间:2019-09-24 00:17:34

标签: julia

我想编写一个函数,如果茱莉亚无法推断该函数的具体返回类型,则将引发错误。如何在没有任何运行时开销的情况下执行此操作?

1 个答案:

答案 0 :(得分:12)

(如果您的函数体是纯函数)执行此操作的一种方法是使用generated function。例如,假设有问题的功能是

f(x) = x + (rand(Bool) ? 1.0 : 1)

我们可以改写

_f(x) = x + (rand(Bool) ? 1.0 : 1)
@generated function f(x)
    out_type = Core.Compiler.return_type(_f, Tuple{x})
    if !isconcretetype(out_type)
        error("$f($x) does not infer to a concrete type")
    end
    :(_f(x))
end

现在我们可以在repl上对此进行测试。浮点输入很好,但是整数错误:

julia> f(1.0)
2.0

julia> f(1)
ERROR: f(Int64) does not infer to a concrete type
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] #s28#4(::Any, ::Any) at ./REPL[5]:4
 [3] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any,N} where N) at ./boot.jl:524
 [4] top-level scope at REPL[8]:1

并且由于我们使用生成的函数的方式,类型检查和错误抛出仅在编译时发生,因此我们不为此付出任何运行时成本。

如果上面的代码对您来说似乎太多了,我们可以编写一个宏来自动生成内部函数和为任意函数签名生成的函数:

using MacroTools: splitdef, combinedef

strip_type_asserts(ex::Expr) = ex.head == :(::) ? ex.args[1] : ex
strip_type_asserts(s) = s

macro checked(fdef)
    d = splitdef(fdef)

    f = d[:name]
    args = d[:args]
    whereparams = d[:whereparams]

    d[:name] = gensym()
    shadow_fdef = combinedef(d)

    args_stripped = strip_type_asserts.(args)

    quote
        $shadow_fdef
        @generated function $f($(args...)) where {$(whereparams...)}
            d = $d
            T = Tuple{$(args_stripped...)}
            shadowf = $(d[:name])
            out_type = Core.Compiler.return_type(shadowf, T)
            sig = collect(T.parameters)
            if !isconcretetype(out_type)
                f = $f
                sig = reduce(*, (", $U" for U in T.parameters[2:end]), init="$(T.parameters[1])")
                error("$f($(sig...)) does not infer to a concrete type")
            end
            args = $args
            #Core.println("statically inferred return type was $out_type")
            :($(shadowf)($(args...)))
        end
    end |> esc
end

现在在REPL中,我们只需要用@checked注释一个函数定义:

julia> @checked g(x, y) = x + (rand(Bool) ? 1.0 : 1)*y
f (generic function with 2 methods)

julia> g(1, 2.0)
3.0

julia> g(1, 2)
ERROR: g(Int64, Int64) does not infer to a concrete type
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] #s28#5(::Any, ::Any, ::Any) at ./REPL[11]:22
 [3] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any,N} where N) at ./boot.jl:524
 [4] top-level scope at REPL[14]:1

编辑:在评论中已指出,我违反了此处使用生成函数的“规则”之一,因为如果有人将编译时在生成函数中发生的情况默默地作废重新定义@checked函数所依赖的函数。例如:

julia> g(x) = x + 1;

julia> @checked f(x) = g(x) + 1;

julia> f(1) # shouldn't error
3

julia> g(x) = rand(Bool) ? 1.0 : 1
g (generic function with 1 method)

julia> f(1) # Should error but doesn't!!!
2.0

julia> f(1)
2

因此请注意:如果您以交互方式使用此类内容,请小心重新定义您依赖的功能。如果出于某种原因决定在程序包中使用此宏,则请注意,实施盗版的人将使您的类型检查无效。

如果有人要尝试将此技术应用于重要代码,我建议您重新考虑,或者认真考虑如何使此技术更安全。如果您有任何提高安全性的想法,我很乐意听到!每次更改从属方法时,也许可以采取一些技巧来强制重新编译功能。