我有几个(~10 ^ 10)点的x,y,z坐标数组(这里只显示了5个)
a= [[ 34.45 14.13 2.17]
[ 32.38 24.43 23.12]
[ 33.19 3.28 39.02]
[ 36.34 27.17 31.61]
[ 37.81 29.17 29.94]]
我想创建一个新数组,其中只包含距离列表中所有其他点至少距离d
的点。我使用while
循环编写了一个代码,
import numpy as np
from scipy.spatial import distance
d=0.1 #or some distance
i=0
selected_points=[]
while i < len(a):
interdist=[]
j=i+1
while j<len(a):
interdist.append(distance.euclidean(a[i],a[j]))
j+=1
if all(dis >= d for dis in interdist):
np.array(selected_points.append(a[i]))
i+=1
这样可行,但执行此计算需要很长时间。我在某处读到while
循环非常慢。
我想知道是否有人对如何加快计算有任何建议。
编辑:虽然我的目标是找到距离所有其他距离至少有一段距离的粒子,但我只是意识到我的代码中有一个严重的缺陷,让我们#39 ;假设我有3个粒子,我的代码执行以下操作,对于i
的第一次迭代,它计算距离1->2
,1->3
,让我们说{{1小于阈值距离1->2
,因此代码会抛弃粒子d
。对于1
的下一次迭代,它只会i
,并且让我们发现它大于2->3
,因此它会保留粒子d
,但这是错的!因为2也应该与粒子2
一起丢弃。 @svohara的解决方案是正确的!
答案 0 :(得分:5)
对于大数据集和低维点(例如三维数据),有时使用空间索引方法有很大好处。低维数据的一个流行选择是k-d树。
策略是索引数据集。然后使用相同的数据集查询索引,以返回每个点的2个最近邻居。第一个最近邻居总是点本身(dist = 0),所以我们真的想知道下一个最近点(第二个最近邻居)有多远。对于那些2-NN> 1的点。阈值,你有结果。
from scipy.spatial import cKDTree as KDTree
import numpy as np
#a is the big data as numpy array N rows by 3 cols
a = np.random.randn(10**8, 3).astype('float32')
# This will create the index, prepare to wait...
# NOTE: took 7 minutes on my mac laptop with 10^8 rand 3-d numbers
# there are some parameters that could be tweaked for faster indexing,
# and there are implementations (not in scipy) that can construct
# the kd-tree using parallel computing strategies (GPUs, e.g.)
k = KDTree(a)
#ask for the 2-nearest neighbors by querying the index with the
# same points
(dists, idxs) = k.query(a, 2)
# (dists, idxs) = k.query(a, 2, n_jobs=4) # to use more CPUs on query...
#Note: 9 minutes for query on my laptop, 2 minutes with n_jobs=6
# So less than 10 minutes total for 10^8 points.
# If the second NN is > thresh distance, then there is no other point
# in the data set closer.
thresh_d = 0.1 #some threshold, equiv to 'd' in O.P.'s code
d_slice = dists[:, 1] #distances to second NN for each point
res = np.flatnonzero( d_slice >= thresh_d )
答案 1 :(得分:2)
这是使用distance.pdist
-
# Store number of pts (number of rows in a)
m = a.shape[0]
# Get the first of pairwise indices formed with the pairs of rows from a
# Simpler version, but a bit slow : idx1,_ = np.triu_indices(m,1)
shifts_arr = np.zeros(m*(m-1)/2,dtype=int)
shifts_arr[np.arange(m-1,1,-1).cumsum()] = 1
idx1 = shifts_arr.cumsum()
# Get the IDs of pairs of rows that are more than "d" apart and thus select
# the rest of the rows using a boolean mask created with np.in1d for the
# entire range of number of rows in a. Index into a to get the selected points.
selected_pts = a[~np.in1d(np.arange(m),idx1[distance.pdist(a) < d])]
对于像10e10
这样的庞大数据集,我们可能必须根据可用的系统内存以块的形式执行操作。
答案 2 :(得分:0)
删除追加,它一定很慢。您可以使用静态矢量距离并使用[]将数字放在正确的位置。
使用min而不是all。您只需要检查最小距离是否大于x。
实际上,你可以在发现距离小于限制的那一刻打破你的追尾,然后你就可以退出两个点。这样你甚至不必保存任何距离(除非你以后需要它们)。
如果你没有重复的观点,我会相信你的评论。
selected_points = []
for p1 in a:
save_point = True
for p2 in a:
if p1!=p2 and distance.euclidean(p1,p2)<d:
save_point = False
break
if save_point:
selected_points.append(p1)
return selected_points
最后,我检查a,b和b,a,因为你不应该在处理它时修改列表,但你可以更聪明地使用一些附加变量。
答案 3 :(得分:0)
你的算法是二次的(10 ^ 20次运算),如果分布几乎是随机的,这是一个线性方法。
将您的空间拆分为大小为d/sqrt(3)^3
的框。将每个点放在其框中。
然后为每个方框
如果只有一个点,你只需要在一个小街区计算距离。
否则无事可做。