在numba中缓存jit编译的函数

时间:2019-07-08 13:59:12

标签: python numba

我想使用numba编译一系列的函数,并且由于我只需要在具有相同签名的计算机上运行它们,所以我想缓存它们。 但是,当尝试这样做时,numba告诉我该函数无法缓存,因为它使用了较大的全局数组。这是它显示的特定警告。

  

Numba警告:由于使用动态全局变量(例如ctypes指针和大型全局数组),因此无法缓存已编译的函数“ sigmoid”

我知道全局数组通常是冻结的,但是大数组却没有,但是我的函数看起来像这样:

@njit(parallel=True, cache=True)
def sigmoid(x):
    return 1./(1. + np.exp(-x))

我看不到任何全局数组,尤其是大型数组。
问题出在哪里?

1 个答案:

答案 0 :(得分:0)

即使对于非常简单的测试,我也观察到了此行为(在Windows 10,Dell Latitude 7480,Git for Windows上运行)。看来parallel=True不允许缓存。这与prange调用的实际存在无关。下面是一个简单的示例。

def where_numba(arr: np.ndarray) -> np.ndarray:
    l0, l1 = np.shape(arr)[0], np.shape(arr)[1]
    for i0 in prange(l0):
        for i1 in prange(l1):
            if arr[i0, i1] > 0.5:
                arr[i0, i1] = arr[i0, i1] * 10
    return(arr)

where_numba_jit = jit(signature_or_function='float64[:,:](float64[:,:])',
                  nopython=True, parallel=True, cache=True, fastmath=True, nogil=True)(where_numba)

arr = np.random.random((10000, 10000))
seln = where_numba_jit(arr)

我得到同样的警告。

我认为您可能会考虑自己的特定代码,并查看哪个选项(cacheparallel)最好保留,显然cache是相对较短的计算时间,而parallel与实际计算时间相比,编译时间可以忽略不计。请发表评论,如果您有更新。

与此有关的还有一个 Numba公开问题https://github.com/numba/numba/issues/2439