Cythonized功能出乎意料地慢了

时间:2012-10-13 20:45:13

标签: python numpy cython

我想加快一个我正在使用的功能,我想使用cython。但是,在尝试了我在文档中找到的所有可能的cython优化之后,cython代码比python + numpy函数慢大约6倍。令人失望!

这是我的测试代码:(forward1是python函数,forward2是cython函数)

#geometry.py
def forward1(points, rotation, translation):
    '''points are in columns'''
    return np.dot(rotation, points - translation[:, np.newaxis])

#geometry.pyx
import numpy as np
cimport numpy as np
cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
cdef np.float64_t[:,:] forward2(np.float64_t[:,:] points, np.float64_t[:,:] rotation, np.float64_t[:] translation):
    '''points are in columns'''
    cdef unsigned int I, J
    I = points.shape[0]
    J = points.shape[1]
    cdef np.float64_t[:,:] tmp = np.empty((I, J), dtype=np.float64)
    cdef unsigned int i
    for i in range(J):
        tmp[0, i] = points[0, i] - translation[0]        
        tmp[1, i] = points[1, i] - translation[1]        
    cdef np.float64_t[:,:] result = np.dot(rotation, tmp)
    return result

def test_forward2(points, rotation, translation):
    import timeit
    cdef np.float64_t[:,:] points2 = points
    cdef np.float64_t[:,:] rotation2 = rotation
    cdef np.float64_t[:] translation2 = translation
    t = timeit.Timer(lambda: forward2(points2, rotation2, translation2))
    print min(t.repeat(3, 10))

然后我计时:

t = timeit.Timer(lambda: forward1(points, rotation, translation))
print min(t.repeat(3, 10))
0.000368164520751

test_forward2(points, rotation, translation)
0.0023365181969

我可以对cython代码做些什么来加快速度吗?

如果在cython中无法加速forward1,我可以希望使用编织加速吗?

修改

仅仅是为了记录,我试图加速这个功能的另一件事是按照fortran顺序传递点,因为我的点存储在列中并且有很多它们。我还将本地tmp定义为fortran顺序。我认为函数的减法部分应该更快但numpy.dot似乎需要一个C阶输出(无论如何解决这个问题?),所以总的来说也没有加速。我还尝试转换点,以便减法部分以C顺序更快,但似乎点积仍然是最昂贵的部分。

另外,我注意到numpy.dot不能将memoryviews用作out参数,即使它是C顺序,这是一个bug吗?

1 个答案:

答案 0 :(得分:4)

只是看一下你的代码,它看起来像是numpy已经非常优化的东西(数组和点积的减法)。

Cython非常适合加速numpy经常表现不佳的情况(例如迭代算法,其中迭代是用python编写的),但在这种情况下,内部循环已经由BLAS库执行。

如果你想加快速度,我首先要看的是BLAS / LAPACK / ATLAS / etc库numpy与之相关联的内容。使用“调谐”线性代数库(例如ATLAS或英特尔MKL)将在这种情况下产生大的(在某些情况下> 10倍)差异。

要了解您当前使用的内容,请查看numpy.show_config()

的输出