使用Numba可以更快地制作四个嵌套for循环

时间:2018-05-09 11:18:08

标签: python loops iteration jit numba

我和Numba一起工作有点新鲜,但我得到了它的要点。我想知道是否还有更高级的技巧可以使四个嵌套for循环比我现在拥有的更快。特别是,我需要计算以下积分:

enter image description here

其中B是2D数组,S0和E是某些参数。我的代码如下:

import numpy as np
from numba import njit, double

def calc_gb_gauss_2d(b,s0,e,dx):
    n,m=b.shape
    norm = 1.0/(2*np.pi*s0**2)
    gb = np.zeros((n,m))
    for i in range(n):
        for j in range(m):
            for ii in range(n):
                for jj in range(m):
                    gb[i,j]+=np.exp(-(((i-ii)*dx)**2+((j-jj)*dx)**2)/(2.0*(s0*(1.0+e*b[i,j]))**2))
            gb[i,j]*=norm
    return gb

calc_gb_gauss_2d_nb = njit(double[:, :](double[:, :],double,double,double))(calc_gb_gauss_2d)

对于并输入大小为256x256的数组,计算速度为:

In [4]: a=random.random((256,256))

In [5]: %timeit calc_gb_gauss_2d_nb(a,0.1,1.0,0.5)
The slowest run took 8.46 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 3: 1min 1s per loop

纯Python和Numba计算速度之间的比较给我这样的图片: enter image description here

有没有办法优化我的代码以获得更好的性能?

1 个答案:

答案 0 :(得分:4)

通过使用numpy和一些数学,可以加速你的代码,因此它比当前的numba版本快一个数量级。我们还将看到,使用numba改进的功能使其更快。

通常,numba被过度使用 - 通常可以编写非常有效的numpy代码 - 这也是这种情况。

手头的numpy代码有一个问题:一个人不应该访问单个元素,而是利用numpy的内置函数 - 它们和大多数时候一样快。只有当不可能使用那些numpy函数时,才会使用numba或cython。

然而,这里最大的问题是问题的制定。对于固定的ij,我们有以下公式来计算(我简化了一点):

 g[i,j]=sum_ii sum_jj exp(value_ii+value_jj)
       =sum_ii sum_jj exp(value_ii)*exp(value_jj)
       =sum_ii exp(value_ii) * sum_jj exp(value_jj)

要评估最后一个公式,我们需要O(n+m)个操作,但对于第一个,天真的公式O(n*m) - 非常不同!

利用numpy功能的第一个版本可能类似于:

def calc_ead(b,s0,e,dx):
    n,m=b.shape
    norm = 1.0/(2*np.pi*s0**2)
    gb = np.zeros((n,m))
    vI=np.arange(n)
    vJ=np.arange(m)
    for i in range(n):
        for j in range(m):
            II=(i-vI)*dx
            JJ=(j-vJ)*dx
            denom=2.0*(s0*(1.0+e*b[i,j]))**2
            expII=np.exp(-II*II/denom)
            expJJ=np.exp(-JJ*JJ/denom)
            gb[i,j]=norm*(expII.sum()*expJJ.sum())
    return gb

现在,与最初的numba实现相比:

>>> a=np.random.random((256,256))

>>> print(calc_gb_gauss_2d_nb(a,0.1,1.0,0.5)[1,1])
15.9160709993
>>> %timeit -n1 -r1 calc_gb_gauss_2d_nb(a,0.1,1.0,0.5)
1min 6s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

现在numpy-function:

>>> print(calc_ead(a,0.1,1.0,0.5)[1,1])
15.9160709993
>>> %timeit -n1 -r1 calc_ead(a,0.1,1.0,0.5)
1.8 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

有两个观察结果:

  1. 结果是一样的。
  2. numpy版本快37倍,对于更大的问题,这种差异会变得更大。
  3. 显然,你可以利用numba进行更大的加速。但是,在可能的情况下使用numpy功能仍然是一个好主意 - 这是非常令人惊讶的,最简单的事情是多么微妙 - 例如甚至calculating a sum

    >>> nb_calc_ead = njit(double[:, :](double[:, :],double,double,double))(calc_ead)
    >>>print(nb_calc_ead(a,0.1,1.0,0.5)[1,1])
    15.9160709993
    >>>%timeit -n1 -r1 nb_calc_ead(a,0.1,1.0,0.5)
    587 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
    

    还有另一个因素3!

    这个问题可以并行化,但要做到这一点并非易事。我的便宜尝试使用explicit loop parallelization

    from numba import njit, prange
    import math
    
    @njit(parallel=True)                 #needed, so it is parallelized
    def parallel_nb_calc_ead(b,s0,e,dx):
        n,m=b.shape
        norm = 1.0/(2*np.pi*s0**2)
        gb = np.zeros((n,m))
        vI=np.arange(n)
        vJ=np.arange(m)
        for i in prange(n):             #outer loop = explicit prange-loop
            for j in range(m):
                denom=2.0*(s0*(1.0+e*b[i,j]))**2
                expII=np.zeros((n,))
                expJJ=np.zeros((m,))
                for k in range(n):
                    II=(i-vI[k])*dx
                    expII[k]=math.exp(-II*II/denom)
    
                for k in range(m):
                    JJ=(j-vJ[k])*dx
                    expJJ[k]=math.exp(-JJ*JJ/denom)
                gb[i,j]=norm*(expII.sum()*expJJ.sum())
        return gb
    

    现在:

    >>> print(parallel_nb_calc_ead(a,0.1,1.0,0.5)[1,1])
    15.9160709993
    >>> %timeit -n1 -r1 parallel_nb_calc_ead(a,0.1,1.0,0.5)
    349 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
    

    意味着几乎另一个因素2(我的机器只有两个CPU,取决于硬件,加速可能更大)。顺便说一句,我们的速度几乎是原始版本的200倍。

    我打赌可以改进上面的代码,但我不会去那里。

    列出与calc_ead进行比较的当前版本:

    import numpy as np
    from numba import njit, double
    
    def calc_gb_gauss_2d(b,s0,e,dx):
        n,m=b.shape
        norm = 1.0/(2*np.pi*s0**2)
        gb = np.zeros((n,m))
        for i in range(n):
            for j in range(m):
                for ii in range(n):
                    for jj in range(m):
                        gb[i,j]+=np.exp(-(((i-ii)*dx)**2+((j-jj)*dx)**2)/(2.0*(s0*(1.0+e*b[i,j]))**2))
                gb[i,j]*=norm
        return gb
    
    calc_gb_gauss_2d_nb = njit(double[:, :](double[:, :],double,double,double))(calc_gb_gauss_2d)