Cython speedup isn't as large as expected

Time: 2015-10-11 17:19:23

Tags: python numpy cython

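I have written a Python function that computes pairwise electromagnetic interactions between a largish number (N ~ 10^3) of particles and stores the results in an NxN complex128 ndarray. It runs, but it is the slowest part of a larger program, taking about 40 seconds when N = 900 [corrected]. The original code looks like this: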

import numpy as np
def interaction(s,alpha,kprop): # s is an Nx3 real array 
                                # alpha is complex
                                # kprop is float

    ndipoles = s.shape[0]

    Amat = np.zeros((ndipoles,3, ndipoles, 3), dtype=np.complex128)
    I = np.array([[1,0,0],[0,1,0],[0,0,1]])
    im = complex(0,1)

    k2 = kprop*kprop

    for i in range(ndipoles):
        xi = s[i,:]
        for j in range(ndipoles):
            if i != j:
                xj = s[j,:]
                dx = xi-xj
                R = np.sqrt(dx.dot(dx))
                n = dx/R
                kR = kprop*R
                kR2 = kR*kR
                A = ((1./kR2) - im/kR)
                nxn = np.outer(n, n)
                nxn = (3*A-1)*nxn + (1-A)*I
                nxn *= -alpha*(k2*np.exp(im*kR))/R
            else:
                nxn = I

            Amat[i,:,j,:] = nxn

    return(Amat.reshape((3*ndipoles,3*ndipoles)))

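I had never used Cython before, but it seemed like a good place to start speeding things up, so I more or less blindly adapted techniques I found in online tutorials. I got some speedup (30 seconds vs. 40 seconds), but nothing as dramatic as I expected, so I'm wondering whether I'm doing something wrong or missing a critical step. The following is my best attempt at cythonizing the above routine: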

import numpy as np
cimport numpy as np

DTYPE = np.complex128
ctypedef np.complex128_t DTYPE_t

def interaction(np.ndarray s, DTYPE_t alpha, float kprop):

    cdef float k2 = kprop*kprop
    cdef int i,j
    cdef np.ndarray xi, xj, dx, n, nxn
    cdef float R, kR, kR2
    cdef DTYPE_t A

    cdef int ndipoles = s.shape[0]
    cdef np.ndarray Amat = np.zeros((ndipoles,3, ndipoles, 3), dtype=DTYPE)
    cdef np.ndarray I = np.array([[1,0,0],[0,1,0],[0,0,1]])
    cdef DTYPE_t im = complex(0,1)

    for i in range(ndipoles):
        xi = s[i,:]
        for j in range(ndipoles):
            if i != j:
                xj = s[j,:]
                dx = xi-xj
                R = np.sqrt(dx.dot(dx))
                n = dx/R
                kR = kprop*R
                kR2 = kR*kR
                A = ((1./kR2) - im/kR)
                nxn = np.outer(n, n)
                nxn = (3*A-1)*nxn + (1-A)*I
                nxn *= -alpha*(k2*np.exp(im*kR))/R
            else:
                nxn = I

            Amat[i,:,j,:] = nxn

    return(Amat.reshape((3*ndipoles,3*ndipoles)))

1 answer:

Answer 0 (score: 11)

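NumPy's real power lies in performing an operation across a huge number of elements in a vectorized fashion, rather than applying that operation to small chunks spread across loops. In your case, you are using two nested loops and one IF conditional statement. I would propose extending the dimensions of the intermediate arrays, which brings NumPy's broadcasting into play, so that the same operations can be applied to all elements in one go instead of to small chunks of data inside the loops.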
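As a small illustration of that idea, the sketch below computes every pairwise displacement vector twice: once with a double loop, and once with a single broadcast subtraction. The three particle positions are made up for the example and are not taken from the question.

```python
import numpy as np

# Three hypothetical particle positions, one row per particle: shape (3, 3).
s = np.array([[0.0, 0.0, 0.0],
              [1.0, 0.0, 0.0],
              [0.0, 2.0, 0.0]])

# Loop version: one small subtraction per (i, j) pair.
dx_loop = np.empty((3, 3, 3))
for i in range(3):
    for j in range(3):
        dx_loop[i, j] = s[i] - s[j]

# Broadcast version: shapes (3, 1, 3) and (1, 3, 3) combine to (3, 3, 3),
# so dx[i, j, :] == s[i] - s[j] for all pairs in a single operation.
dx = s[:, np.newaxis, :] - s[np.newaxis, :, :]

# Pairwise distances follow from one reduction over the last axis.
R = np.sqrt((dx ** 2).sum(axis=-1))
```

Both versions produce the same (3, 3, 3) array. The broadcast version does the per-pair work inside NumPy's compiled loops rather than in the Python interpreter.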

To extend the dimensions, np.newaxis (or None) indexing can be used. A vectorized implementation following this premise would look like this -

{{1}}

Runtime tests and output verification -

Case #1:

In [703]: N = 10
     ...: s = np.random.rand(N,3) + complex(0,1)*np.random.rand(N,3)
     ...: alpha = 3j
     ...: kprop = 5.4
     ...: 

In [704]: out_org = interaction(s,alpha,kprop)
     ...: out_vect = vectorized_interaction(s,alpha,kprop)
     ...: print np.allclose(np.real(out_org),np.real(out_vect))
     ...: print np.allclose(np.imag(out_org),np.imag(out_vect))
     ...: 
True
True

In [705]: %timeit interaction(s,alpha,kprop)
100 loops, best of 3: 7.6 ms per loop

In [706]: %timeit vectorized_interaction(s,alpha,kprop)
1000 loops, best of 3: 304 µs per loop

Case #2:

In [707]: N = 100
     ...: s = np.random.rand(N,3) + complex(0,1)*np.random.rand(N,3)
     ...: alpha = 3j
     ...: kprop = 5.4
     ...: 

In [708]: out_org = interaction(s,alpha,kprop)
     ...: out_vect = vectorized_interaction(s,alpha,kprop)
     ...: print np.allclose(np.real(out_org),np.real(out_vect))
     ...: print np.allclose(np.imag(out_org),np.imag(out_vect))
     ...: 
True
True

In [709]: %timeit interaction(s,alpha,kprop)
1 loops, best of 3: 826 ms per loop

In [710]: %timeit vectorized_interaction(s,alpha,kprop)
100 loops, best of 3: 14 ms per loop

Case #3:

In [711]: N = 900
     ...: s = np.random.rand(N,3) + complex(0,1)*np.random.rand(N,3)
     ...: alpha = 3j
     ...: kprop = 5.4
     ...: 

In [712]: out_org = interaction(s,alpha,kprop)
     ...: out_vect = vectorized_interaction(s,alpha,kprop)
     ...: print np.allclose(np.real(out_org),np.real(out_vect))
     ...: print np.allclose(np.imag(out_org),np.imag(out_vect))
     ...: 
True
True

In [713]: %timeit interaction(s,alpha,kprop)
1 loops, best of 3: 1min 7s per loop

In [714]: %timeit vectorized_interaction(s,alpha,kprop)
1 loops, best of 3: 1.59 s per loop
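The runs above call vectorized_interaction, but its listing is not reproduced in this copy of the answer. To give a rough idea of what such a function could look like, here is a minimal sketch. It is not the answerer's code. It assumes s is a real (N, 3) array, as stated in the question's comments, and the name vectorized_interaction_sketch is made up for this illustration.

```python
import numpy as np


def vectorized_interaction_sketch(s, alpha, kprop):
    """Broadcasting-based sketch of the pairwise interaction matrix.

    Assumes s is a real (N, 3) array, alpha is complex, kprop is a float.
    """
    s = np.asarray(s, dtype=np.float64)
    n_dipoles = s.shape[0]
    eye3 = np.eye(3)

    # dx[i, j, :] = s[i] - s[j], shape (N, N, 3).
    dx = s[:, np.newaxis, :] - s[np.newaxis, :, :]
    # Pairwise distances, shape (N, N); the diagonal is zero.
    R = np.sqrt(np.einsum("ijk,ijk->ij", dx, dx))
    # Put a dummy 1 on the diagonal so the divisions below stay finite;
    # the diagonal blocks are overwritten with the identity afterwards.
    np.fill_diagonal(R, 1.0)

    n = dx / R[:, :, np.newaxis]
    kR = kprop * R
    A = 1.0 / (kR * kR) - 1j / kR

    # Outer product of n with itself for every pair, shape (N, N, 3, 3).
    nxn = n[:, :, :, np.newaxis] * n[:, :, np.newaxis, :]
    A4 = A[:, :, np.newaxis, np.newaxis]
    blocks = (3 * A4 - 1) * nxn + (1 - A4) * eye3

    # Scalar factor per pair: -alpha * k^2 * exp(i k R) / R.
    prefactor = -alpha * (kprop * kprop) * np.exp(1j * kR) / R
    blocks = blocks * prefactor[:, :, np.newaxis, np.newaxis]

    # i == j case of the loop version: the block is the 3x3 identity.
    idx = np.arange(n_dipoles)
    blocks[idx, idx] = eye3

    # Reorder (i, j, a, b) -> (i, a, j, b) to match Amat[i, :, j, :].
    return blocks.transpose(0, 2, 1, 3).reshape(3 * n_dipoles, 3 * n_dipoles)
```

One trade-off of this approach: it materialises several (N, N, 3, 3) complex128 intermediates. For N = 900 each of them holds about 7.3 million values (roughly 117 MB), so memory use grows quickly with N.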