我正在努力学习cython;但是,我一定是做错了。这个小小的测试代码运行速度比我的矢量化numpy版本慢了大约50倍。有人可以告诉我为什么我的cython比我的python慢?谢谢。
该代码计算R ^ 3,loc中的点与R ^ 3中的点数组之间的距离。
import numpy as np
cimport numpy as np
import cython
cimport cython
DTYPE = np.float64
ctypedef np.float64_t DTYPE_t
@cython.boundscheck(False) # turn of bounds-checking for entire function
@cython.wraparound(False)
@cython.nonecheck(False)
def distMeasureCython(np.ndarray[DTYPE_t, ndim=2] points, np.ndarray[DTYPE_t, ndim=1] loc):
cdef unsigned int i
cdef unsigned int L = points.shape[0]
cdef np.ndarray[DTYPE_t, ndim=1] d = np.zeros(L)
for i in xrange(0,L):
d[i] = np.sqrt((points[i,0] - loc[0])**2 + (points[i,1] - loc[1])**2 + (points[i,2] - loc[2])**2)
return d
这是与它进行比较的numpy代码。
from numpy import *
N = 1e6
points = random.uniform(0,1,(N,3))
loc = random.uniform(0,1,(3))
def distMeasureNumpy(points,loc):
d = points - loc
d = sqrt(sum(d*d,axis=1))
return d
numpy / python版本大约需要44毫秒,cython版本大约需要2秒。我在mac osx上运行python 2.7。我正在使用ipython的%timeit命令来计时这两个函数。
答案 0 :(得分:5)
对np.sqrt
的调用,这是一个Python函数调用,正在扼杀你的性能你正在计算标量浮点值的平方根,所以你应该使用C数学中的sqrt
函数图书馆。这是您的代码的修改版本:
import numpy as np
cimport numpy as np
import cython
cimport cython
from libc.math cimport sqrt
DTYPE = np.float64
ctypedef np.float64_t DTYPE_t
@cython.boundscheck(False) # turn of bounds-checking for entire function
@cython.wraparound(False)
@cython.nonecheck(False)
def distMeasureCython(np.ndarray[DTYPE_t, ndim=2] points,
np.ndarray[DTYPE_t, ndim=1] loc):
cdef unsigned int i
cdef unsigned int L = points.shape[0]
cdef np.ndarray[DTYPE_t, ndim=1] d = np.zeros(L)
for i in xrange(0,L):
d[i] = sqrt((points[i,0] - loc[0])**2 +
(points[i,1] - loc[1])**2 +
(points[i,2] - loc[2])**2)
return d
以下说明了性能改进。您的原始代码位于模块check_speed_original
中,修改后的版本位于check_speed
:
In [11]: import check_speed_original
In [12]: import check_speed
设置测试数据:
In [13]: N = 10**6
In [14]: points = random.uniform(0,1,(N,3))
In [15]: loc = random.uniform(0,1,(3,))
我的电脑上的原始版本需要1.26秒:
In [16]: %timeit check_speed_original.distMeasureCython(points, loc)
1 loops, best of 3: 1.26 s per loop
修改后的版本需要4.47 毫秒:
In [17]: %timeit check_speed.distMeasureCython(points, loc)
100 loops, best of 3: 4.47 ms per loop
如果有人担心结果可能会有所不同:
In [18]: d1 = check_speed.distMeasureCython(points, loc)
In [19]: d2 = check_speed_original.distMeasureCython(points, loc)
In [20]: np.all(d1 == d2)
Out[20]: True
答案 1 :(得分:3)
如前所述,它是代码中的numpy.sqrt调用。但是,我认为不需要使用cdef extern
,因为Cython已经提供了这些基本的C / C ++库。 (见the docs)。所以你可以这样插入它:
from libc.math cimport sqrt
只是为了摆脱开销。