我试图找到两组数组之间的最短距离。 x-数组是相同的,只包含整数。这是我想要做的一个例子:
import numpy as np
x1 = x2 = np.linspace(-1000, 1000, 2001)
y1 = (lambda x, a, b: a*x + b)(x1, 2, 1)
y2 = (lambda x, a, b: a*(x-2)**2 + b)(x2, 2, 10)
def dis(x1, y1, x2, y2):
return sqrt((y2-y1)**2+(x2-x1)**2)
min_distance = np.inf
for a, b in zip(x1, y1):
for c, d in zip(x2, y2):
if dis(a, b, c, d) < min_distance:
min_distance = dis(a, b, c, d)
>>> min_distance
2.2360679774997898
此解决方案有效,但问题是运行时。如果x的长度为~10,000,则解决方案是不可行的,因为程序ha O(n ^ 2)运行时。现在,我尝试做一些近似来加速程序:
for a, b in zip(x1, y1):
cut = (x2 > a-20)*(x2 < a+20)
for c, d in zip(x2, y2):
if dis(a, b, c, d) < min_distance:
min_distance = dis(a, b, c, d)
但该计划仍然比我想要的时间更长。现在,根据我的理解,循环一个numpy数组通常是低效的,所以我确信仍有改进的余地。关于如何加速这个程序的任何想法?
答案 0 :(得分:1)
您的问题也可以表示为2d碰撞检测,因此quadtree可能有所帮助。插入和查询都在O(log n)时间运行,因此整个搜索将在O(n log n)中运行。
还有一个建议,因为sqrt是单调的,你可以比较距离的平方而不是距离本身,这将节省你n ^ 2平方根计算。
答案 1 :(得分:1)
scipy
有一个cdist
function来计算所有点对之间的距离:
from scipy.spatial.distance import cdist
import numpy as np
x1 = x2 = np.linspace(-1000, 1000, 2001)
y1 = (lambda x, a, b: a*x + b)(x1, 2, 1)
y2 = (lambda x, a, b: a*(x-2)**2 + b)(x2, 2, 10)
R1 = np.vstack((x1,y1)).T
R2 = np.vstack((x2,y2)).T
dists = cdist(R1,R2) # find all mutual distances
print (dists.min())
# output: 2.2360679774997898
这比原来的循环速度快250倍。
答案 2 :(得分:0)
这是一个难题,如果您愿意接受近似值,它可能会有所帮助。我会查看Spottify的annoy之类的内容。