为什么这段代码不能用 Numba 编译?

时间:2021-07-17 22:05:29

标签: python jit numba

我有一个示例代码来说明我的问题。如果你运行:

import numpy as np
from numba import jit


@jit(nopython=True)
def test():
    arr = np.array([[[11, 12, 13], [11, 12, 13]], [[21, 22, 23], [21, 22, 23]]])

    arr2 = arr[:, 0, :]

    arr3 = arr2.argsort()

    print(arr3)

test()

它会失败:

numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of BoundFunction(array.argsort for array(int64, 2d, A)) with parameters ()
During: resolving callee type: BoundFunction(array.argsort for array(int64, 2d, A))
During: typing of call at /home/stark/Work/mmr6/test.py (41)


File "test.py", line 41:
def test():
    <source elided>

    arr3 = arr2.argsort()
    ^

argsort 应该在最后一个轴上进行 argsort。基本上它应该给我:

>>>
[[0 1 2]
 [0 1 2]]

我认为复制 arr2 数组(使用 copy())可以解决,因为它会使数组在内存中连续(而不是视图),但它失败并显示相同的消息,只是类型不同消息中的 arr2 现在符合预期的 array(int64, 2d, C)

为什么会失败,我该如何解决?

1 个答案:

答案 0 :(得分:3)

遗憾的是,这是目前已知的 Numba 限制。见this issue。目前只支持一维数组。但是,您的情况有一个简单的解决方法:

import numpy as np
from numba import jit


@jit(nopython=True)
def test():
    arr = np.array([[[11, 12, 13], [11, 12, 13]], [[21, 22, 23], [21, 22, 23]]])

    arr2 = arr[:, 0, :]

    arr3 = np.empty(arr2.shape, dtype=arr2.dtype)
    for i in range(arr2.shape[0]):
        arr3[i] = arr2[i, :].argsort()

    print(arr3)

test()

请注意,即使实现了,也不会更快。见this issue。实际上,对于任何给定的 Numpy 原语,Numba 没有理由更快。但是,您可以使用 Numba 手动编写自己的 Numpy 原语版本,有时由于算法专业化、并行性或数学优化(例如快速数学),速度会有所提高。当您想要执行 Numpy 中尚未/直接可用的有效操作时,Numba 通常非常有用,并且可以使用循环轻松实现此操作。

实际上,假设 prange 尚未并行运行,您可以使用 Numba 的 parallel=True 和 JIT 参数 argsort 来加快计算速度(AFAIK 它应该是顺序的) .这应该比在大数组上(在小数组上,产生多个线程的成本可能大于实际计算的成本)上的 Numpy 实现(也不应该按顺序运行)快一点。