例如,找到下面的图片,这解释了简单2D案例的问题。每个点的标签(N)和坐标(x,y)是已知的。我需要找到位于红圈内的所有点标签
我的实际问题是3D,点数不均匀分布
此处附有包含7.25 M点坐标的示例输入文件point file。
我尝试了下面这段代码
import numpy as np
C = [50,50,50]
R = 20
centroid = np.loadtxt('centroid') #chk the file attached
def dist(x,y): return sum([(xi-yi)**2 for xi, yi in zip(x,y)])
elabels=[i+1 for i in range(len(centroid)) if dist(C,centroid[i])<=R**2]
对于单次搜索,需要约10分钟。有什么建议让它更快?
谢谢, Prithivi
答案 0 :(得分:2)
使用numpy
时,请避免在数组上使用列表推导。
您的计算可以使用像这样的矢量化表达式完成
centre = np.array((50., 50., 50.))
points = np.loadtxt('data')
distances2= np.sum((points-centre)**2, axis=1)
points
是N x 2
数组,points-centre
也是N x 2
数组,
(points-centre)**2
计算差异的每个元素的平方,并且最终np.sum(..., axis=1)
对沿轴no的平方差异的元素求和。 1,即跨列。
要过滤位置数组,可以使用布尔索引
close = points[distances2<max_dist**2]
答案 1 :(得分:1)
你正在大量调用dist
函数。您可以尝试对其进行低级优化,并使用更高效的timeit Python模块进行控制。在我的机器上,我尝试了这个:
def dist(x,y):
d0 = y[0] -x[0]
d1 = y[1] -x[1]
d2 = y[2] -x[2]
return d0 * d0 + d1*d1 + d2*d2
and timeit说它快了3倍多。
这个只是在中间:
def dist(x,y):
s = 0
for i in range(len(x)):
d = y[i] - x[i]
s += d * d
return s