为什么依赖的numba jitt'ed函数的顺序很重要?

时间:2019-03-28 14:30:22

标签: python jit numba

在python中,您可以定义多个以任意顺序相互调用的函数,并且在运行时将调用这些函数。一旦存在,这些功能在脚本中定义的顺序就无关紧要。例如,以下内容有效,并且可以正常工作

import numpy as np

def func1(arr):
    out = np.empty_like(arr)
    for i in range(arr.shape[0]):
        out[i] = func2(arr[i])  # calling func2 here which is defined below
    return out

def func2(a):
    out = a + 1
    return out

func1可以调用func2,即使func2是在func1之后定义的。

但是,如果我用numba装饰这些功能,则会收到错误消息

import numpy as np
import numba as nb


@nb.jit("f8[:](f8[:])", nopython=True)
def func1(arr):
    out = np.empty_like(arr)
    for i in range(arr.shape[0]):
        out[i] = func2(arr[i])
    return out

@nb.jit("f8(f8)", nopython=True)
def func2(a):
    out = a + 1
    return out

>>> TypingError: Failed in nopython mode pipeline (step: nopython frontend)
    Untyped global name 'func2': cannot determine Numba type of <class 
    'numba.ir.UndefinedType'>

因此,numba在使用JIT编译func2时不知道func1是什么。不过,只需切换这些功能的顺序即可,因此func2func1

之前
@nb.jit("f8(f8)", nopython=True)
def func2(a):
    out = a + 1
    return out

@nb.jit("f8[:](f8[:])", nopython=True)
def func1(arr):
    out = np.empty_like(arr)
    for i in range(arr.shape[0]):
        out[i] = func2(arr[i])
    return out

这是为什么?我有一种纯粹的python模式可以使用的感觉,因为python是动态类型化的,而不是编译的,而numba根据定义使用JIT确实可以编译函数(因此可能需要对每个函数中发生的一切都有全面的了解?)。但是我不明白,如果numba遇到未见过的功能,为什么它不在所有功能范围内搜索。

1 个答案:

答案 0 :(得分:3)

简短版本-删除"f8[:](f8[:])"

您的直觉是正确的。在 call 时检查Python函数,这就是为什么可以不按顺序定义它们的原因。通过dis(反汇编)模块查看python字节码可以使这一点变得很清楚-每次调用函数b时,名称a都会被查询为全局变量。

def a():
    return b()

def b():
    return 2

import dis
dis.dis(a)
#  2           0 LOAD_GLOBAL              0 (b)
#              2 CALL_FUNCTION            0
#              4 RETURN_VALUE

在nopython模式下,numba需要静态知道正在调用的每个函数的地址-这样可以使代码运行更快(不再执行运行时查找),并且还为其他代码打开了大门。优化,例如内联。

也就是说,numba 可以处理这种情况。通过指定类型签名("f8[:](f8[:])"),您将强制提前编译。忽略它,一个数字将延迟到第一个调用它的函数,它将起作用。