I want to remove points that are within 10 cm of an earlier point in my data.
This is what I have, but it needs a lot of computation time because my dataset is very large:
for i in range(len(data)):
    for j in range(i, len(data)):
        if i == j:
            continue
        elif np.sqrt((data[i, 0]-data[j, 0])**2 + (data[i, 1]-data[j, 1])**2) <= 0.1:
            data[j, 0] = np.nan
data = data[~np.isnan(data).any(axis=1)]
Is there a pythonic way to do this?
Answer 0 (score: 3)
Here is a way using a KDTree:
import numpy as np
from scipy.spatial import cKDTree as KDTree
def cluster_data_KDTree(a, thr=0.1):
    t = KDTree(a)
    mask = np.ones(a.shape[:1], bool)
    idx = 0
    nxt = 1
    while nxt:
        # eliminate every point within thr of the current point (itself included)
        mask[t.query_ball_point(a[idx], thr)] = False
        # jump to the next point that has not been eliminated yet
        nxt = mask[idx:].argmax()
        mask[idx] = True  # keep the current point
        idx += nxt
    return a[mask]
Borrowing @Divakar's test case, we see that this gives another 100x speedup on top of the 400x Divakar reports, so relative to the OP we extrapolate a ridiculous 40,000x:
np.random.seed(0)
data1 = np.random.rand(10000,2)
data2 = data1.copy()
from timeit import timeit
kwds = dict(globals=globals(), number=10)
print(timeit("cluster_data_KDTree(data1)", **kwds))
print(timeit("cluster_data_pdist_v1(data2)", **kwds))
np.random.seed(0)
data1 = np.random.rand(10000,2)
data2 = data1.copy()
out1 = cluster_data_KDTree(data1, thr=0.1)
out2 = cluster_data_pdist_v1(data2, dist_thresh = 0.1)
print(np.allclose(out1, out2))
Sample output:
0.05073001119308174
5.646531613077968
True
It turns out that this test case happens to be very favorable for my approach, because there are very few clusters and hence very few iterations. If we drastically increase the number of clusters to roughly 3800 by changing the threshold to 0.01, KDTree still wins, but the speedup shrinks from 100x to 15x.
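As a self-contained sketch of that experiment (redefining cluster_data_KDTree from above so the snippet runs standalone), one can compare how many points survive at the two thresholds; the exact counts depend on the random seed:

```python
import numpy as np
from scipy.spatial import cKDTree as KDTree

def cluster_data_KDTree(a, thr=0.1):
    t = KDTree(a)
    mask = np.ones(a.shape[:1], bool)
    idx = 0
    nxt = 1
    while nxt:
        mask[t.query_ball_point(a[idx], thr)] = False
        nxt = mask[idx:].argmax()
        mask[idx] = True
        idx += nxt
    return a[mask]

np.random.seed(0)
data = np.random.rand(10000, 2)

# A smaller radius eliminates fewer neighbors, so far more points survive,
# and the while loop runs once per surviving point.
out_coarse = cluster_data_KDTree(data.copy(), thr=0.1)
out_fine = cluster_data_KDTree(data.copy(), thr=0.01)
print(len(out_coarse), len(out_fine))
```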
Answer 1 (score: 2)
We can use pdist with one loop -
import numpy as np
from scipy.spatial.distance import pdist

def cluster_data_pdist_v1(a, dist_thresh=0.1):
    d = pdist(a)
    mask = d <= dist_thresh
    n = len(a)
    # start/stop offsets of each row in the condensed distance vector
    idx = np.concatenate(([0], np.arange(n-1, 0, -1).cumsum()))
    start, stop = idx[:-1], idx[1:]
    idx_out = np.zeros(mask.sum(), dtype=int)  # use np.empty for a bit more speedup
    cur_start = 0
    for iterID, (i, j) in enumerate(zip(start, stop)):
        if iterID not in idx_out[:cur_start]:  # skip points already marked for removal
            rm_idx = np.flatnonzero(mask[i:j]) + iterID + 1
            L = len(rm_idx)
            idx_out[cur_start:cur_start+L] = rm_idx
            cur_start += L
    return np.delete(a, idx_out[:cur_start], axis=0)
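The index arithmetic above relies on pdist's condensed layout: row i's distances to points i+1..n-1 occupy the slice idx[i]:idx[i+1]. A tiny sketch with toy data (values chosen only for illustration) makes this visible:

```python
import numpy as np
from scipy.spatial.distance import pdist, squareform

a = np.array([[0.0, 0.0], [0.0, 0.05], [1.0, 1.0], [1.0, 1.02]])
d = pdist(a)  # condensed vector of n*(n-1)//2 pairwise distances
n = len(a)

# Same offset construction as in cluster_data_pdist_v1
idx = np.concatenate(([0], np.arange(n - 1, 0, -1).cumsum()))

# Slice idx[0]:idx[1] holds the distances from point 0 to points 1, 2, 3,
# which matches row 0 of the full square distance matrix.
print(d[idx[0]:idx[1]])
print(squareform(d)[0, 1:])
```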
Original approach -
def cluster_data_org(data, dist_thresh=0.1):
    for i in range(len(data)):
        for j in range(i, len(data)):
            if i == j:
                continue
            elif np.sqrt((data[i, 0]-data[j, 0])**2 +
                         (data[i, 1]-data[j, 1])**2) <= dist_thresh:
                data[j, 0] = np.nan
    return data[~np.isnan(data).any(axis=1)]
Runtime test and verification on random data in the range [0, 1) with 10,000 points -
In [207]: np.random.seed(0)
...: data1 = np.random.rand(10000,2)
...: data2 = data1.copy()
...:
...: out1 = cluster_data_org(data1, dist_thresh = 0.1)
...: out2 = cluster_data_pdist_v1(data2, dist_thresh = 0.1)
     ...: print(np.allclose(out1, out2))
True
In [208]: np.random.seed(0)
...: data1 = np.random.rand(10000,2)
...: data2 = data1.copy()
In [209]: %timeit cluster_data_org(data1, dist_thresh = 0.1)
1 loop, best of 3: 1min 50s per loop
In [210]: %timeit cluster_data_pdist_v1(data2, dist_thresh = 0.1)
1 loop, best of 3: 287 ms per loop
Around 400x speedup with such a setup!
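As a sanity check (a sketch of my own, not from either answer), both methods should leave only points that are pairwise farther apart than the threshold, which pdist can confirm on the KDTree result:

```python
import numpy as np
from scipy.spatial import cKDTree as KDTree
from scipy.spatial.distance import pdist

def cluster_data_KDTree(a, thr=0.1):
    t = KDTree(a)
    mask = np.ones(a.shape[:1], bool)
    idx = 0
    nxt = 1
    while nxt:
        mask[t.query_ball_point(a[idx], thr)] = False
        nxt = mask[idx:].argmax()
        mask[idx] = True
        idx += nxt
    return a[mask]

np.random.seed(0)
out = cluster_data_KDTree(np.random.rand(10000, 2), thr=0.1)

# query_ball_point removes everything at distance <= thr from a kept point,
# so every surviving pair must be strictly farther apart than thr.
print(pdist(out).min() > 0.1)
```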