找到两个阵列之间最短距离的有效方法?

时间:2016-07-29 22:41:56

标签: python arrays numpy runtime ipython

我试图找到两组数组之间的最短距离。 x-数组是相同的,只包含整数。这是我想要做的一个例子:

import numpy as np
x1 = x2 = np.linspace(-1000, 1000, 2001)
y1 = (lambda x, a, b: a*x + b)(x1, 2, 1)
y2 = (lambda x, a, b: a*(x-2)**2 + b)(x2, 2, 10)

def dis(x1, y1, x2, y2):
    return sqrt((y2-y1)**2+(x2-x1)**2)

min_distance = np.inf
for a, b in zip(x1, y1):
    for c, d in zip(x2, y2):
        if dis(a, b, c, d) < min_distance:
            min_distance = dis(a, b, c, d)

>>> min_distance
2.2360679774997898

此解决方案有效,但问题是运行时。如果x的长度为~10,000,则解决方案是不可行的,因为程序ha O(n ^ 2)运行时。现在,我尝试做一些近似来加速程序:

for a, b in zip(x1, y1):
    cut = (x2 > a-20)*(x2 < a+20)
    for c, d in zip(x2, y2):
        if dis(a, b, c, d) < min_distance:
            min_distance = dis(a, b, c, d)

但该计划仍然比我想要的时间更长。现在,根据我的理解,循环一个numpy数组通常是低效的,所以我确信仍有改进的余地。关于如何加速这个程序的任何想法?

3 个答案:

答案 0 :(得分:1)

您的问题也可以表示为2d碰撞检测,因此quadtree可能有所帮助。插入和查询都在O(log n)时间运行,因此整个搜索将在O(n log n)中运行。

还有一个建议,因为sqrt是单调的,你可以比较距离的平方而不是距离本身,这将节省你n ^ 2平方根计算。

答案 1 :(得分:1)

scipy有一个cdist function来计算所有点对之间的距离:

from scipy.spatial.distance import cdist
import numpy as np

x1 = x2 = np.linspace(-1000, 1000, 2001)
y1 = (lambda x, a, b: a*x + b)(x1, 2, 1)
y2 = (lambda x, a, b: a*(x-2)**2 + b)(x2, 2, 10)

R1 = np.vstack((x1,y1)).T
R2 = np.vstack((x2,y2)).T

dists = cdist(R1,R2) # find all mutual distances

print (dists.min())
# output: 2.2360679774997898

这比原来的循环速度快250倍。

答案 2 :(得分:0)

这是一个难题,如果您愿意接受近似值,它可能会有所帮助。我会查看Spottify的annoy之类的内容。