计算python中每个点之间距离的最快方法

时间:2016-05-18 11:48:22

标签: python numpy optimization cython

在我的项目中,我需要计算存储在数组中的每个点之间的欧几里德距离。 入口数组是一个2D numpy数组,有3列,它们是坐标(x,y,z),每行定义一个新点。

我通常在我的测试用例中使用5000 - 6000点。

我的第一个算法使用Cython和我的第二个numpy。我发现我的numpy算法比cython快。

编辑:6000分:

numpy 1.76 s / cython 4.36 s

这是我的cython代码:

cimport cython
from libc.math cimport sqrt
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void calcul1(double[::1] M,double[::1] R):

  cdef int i=0
  cdef int max = M.shape[0]
  cdef int x,y
  cdef int start = 1

  for x in range(0,max,3):
     for y in range(start,max,3):

        R[i]= sqrt((M[y] - M[x])**2 + (M[y+1] - M[x+1])**2 + (M[y+2] - M[x+2])**2)
        i+=1  

     start += 1

M是初始条目数组的内存视图,但是在调用函数flatten()之前是numpy的calcul1(),R是存储所有结果的1D输出数组的内存视图。 / p>

这是我的Numpy代码:

def calcul2(M):

     return np.sqrt(((M[:,:,np.newaxis] - M[:,np.newaxis,:])**2).sum(axis=0))

这里M是初始入口数组,但是在函数调用之前由numpy transpose()将坐标(x,y,z)作为行和点作为列。

此外,这个numpy函数非常方便,因为它返回的数组组织得很好。它是一个n×n数组,n为点数,每个点有一行和一列。因此,例如,距离AB存储在行A和列B的交叉点索引处。

以下是我如何称呼它们(cython功能):

cpdef test():

  cdef double[::1] Mf 
  cdef double[::1] out = np.empty(17998000,dtype=np.float64) # (6000² - 6000) / 2

  M = np.arange(6000*3,dtype=np.float64).reshape(6000,3) # Example array with 6000 points
  Mf = M.flatten() #because my cython algorithm need a 1D array
  Mt = M.transpose() # because my numpy algorithm need coordinates as rows

  calcul2(Mt)

  calcul1(Mf,out)

我在这里做错了吗?对于我的项目,两者都不够快。

1:有没有办法改进我的cython代码以击败numpy的速度?

2:有没有办法改进我的numpy代码以便更快地计算?

3:或任何其他解决方案,但它必须是python / cython(如并行计算)?

谢谢。

1 个答案:

答案 0 :(得分:5)

不确定您的时间安排在哪里,但您可以使用scipy.spatial.distance

M = np.arange(6000*3, dtype=np.float64).reshape(6000,3)
np_result = calcul2(M)
sp_result = sd.cdist(M.T, M.T) #Scipy usage
np.allclose(np_result, sp_result)
>>> True

时序:

%timeit calcul2(M)
1000 loops, best of 3: 313 µs per loop

%timeit sd.cdist(M.T, M.T)
10000 loops, best of 3: 86.4 µs per loop

重要的是,它对于实现输出是对称的也很有用:

np.allclose(sp_result, sp_result.T)
>>> True

另一种方法是仅计算此数组的上三角形:

%timeit sd.pdist(M.T)
10000 loops, best of 3: 39.1 µs per loop

编辑:不确定要压缩哪个索引,看起来你可能两种方式都这样做?压缩另一个索引进行比较:

%timeit sd.pdist(M)
10 loops, best of 3: 135 ms per loop

仍然比您当前的NumPy实施快10倍。