嵌套Numpy数组上的Numba

时间:2019-07-11 15:12:34

标签: python arrays numpy matrix numba

设置

我有以下两种矩阵计算实现:

  1. 第一个实现使用matrix of shape (n, m),并且在for循环中重复计算repetition次:
import numpy as np
from numba import jit

@jit
def foo():
    for i in range(1, n):
        for j in range(1, m):

            _deleteA = (
                        matrix[i, j] +
                        #some constants added here
            )
            _deleteB = (
                        matrix[i, j-1] +
                        #some constants added here
            )
            matrix[i, j] = min(_deleteA, _deleteB)

    return matrix

repetition = 3
for x in range(repetition):
    foo()


2.第二种实现方式避免了额外的for循环,因此将repetition = 3包含在矩阵中,该矩阵即为shape (repetition, n, m)的矩阵:

@jit
def foo():
    for i in range(1, n):
        for j in range(1, m):

            _deleteA = (
                        matrix[:, i, j] +
                        #some constants added here
            )
            _deleteB = (
                        matrix[:, i, j-1] +
                        #some constants added here
            )
            matrix[:, i, j] = np.amin(np.stack((_deleteA, _deleteB), axis=1), axis=1)

    return matrix


问题

关于这两种实现,我发现在iPython中使用%timeit的性能有两件事。

  1. 第一个实现从@jit中获得了可观的利润,而第二个实现则根本不赚钱(在我的测试用例中为28ms vs. 25sec)。 有人能想象为什么@jit不再适用于形状为(repetition, n, m)的numpy数组吗?


编辑

我将先前的第二个问题移至an extra post,因为询问多个问题被认为是不好的SO风格。

问题是:

  1. 忽略@jit时,第一个实现仍然要快得多(相同的测试用例:17秒vs. 26秒)。 为什么在三维而不是二维上工作时,numpy的速度会变慢?

1 个答案:

答案 0 :(得分:3)

我不确定您的设置在这里,但是我稍微重写了您的示例:

import numpy as np
from numba import jit

#@jit(nopython=True)
def foo(matrix):
    n, m = matrix.shape
    for i in range(1, n):
        for j in range(1, m):

            _deleteA = (
                        matrix[i, j] #+
                        #some constants added here
            )
            _deleteB = (
                        matrix[i, j-1] #+
                        #some constants added here
            )
            matrix[i, j] = min(_deleteA, _deleteB)

    return matrix

foo_jit = jit(nopython=True)(foo)

然后是时间:

m = np.random.normal(size=(100,50))

%timeit foo(m)  # in a jupyter notebook
# 2.84 ms ± 54.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit foo_jit(m)  # in a jupyter notebook
# 3.18 µs ± 38.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

因此,numba的速度比预期的要快得多。要考虑的一件事是全局numpy数组在numba中的行为不像您期望的那样:

https://numba.pydata.org/numba-doc/dev/user/faq.html#numba-doesn-t-seem-to-care-when-i-modify-a-global-variable

通常最好像我在示例中那样传递数据。

在第二种情况下,您的问题是numba目前不支持amin。参见:

https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html

如果将nopython=True传递给jit,则可以看到此信息。因此,在当前版本的numba(当前为0.44或更早版本)中,它会退回到objectmode,这通常不会比不使用numba快,但有时会变慢,因为存在一些调用开销。