具有指定范围的最近邻1维数据

时间:2013-09-15 20:12:45

标签: python algorithm numpy scipy nearest-neighbor

我有两个嵌套列表A和B:

A = [[50,140],[51,180],[54,500],......]

B = [[50.1, 170], [51,200],[55,510].....]

每个内部列表中的第一个元素从0到大约1e5,第0个元素从大约50到大约700,这些元素是未排序的。我想要做的是,运行A [n] [1]中的每个元素并找到B [n] [1]中最接近的元素,但是当搜索最近的邻居时,我想仅在由A [n] [0]加或减0.5。

我一直在使用这个功能:

def find_nearest_vector(array, value): 
   idx = np.array([np.linalg.norm(x+y) for (x,y) in array-value]).argmin()
   return array[idx]

例如,它找到坐标A[0][:]B[0][:]之间的最近邻居。但是,我需要将搜索范围限制在值A [0] [0]的某个小移位周围的矩形。另外,我不想重复使用元素 - 我希望在区间A [n] [0] +/- 0.5内的每个值A [n] [1]到B [n] [1]之间有一个独特的双射。

我一直在尝试使用Scipy的KDTree,但这会重用元素,我不知道如何限制搜索范围。实际上,我想在沿着特定轴的二维嵌套列表上进行一维NNN搜索,其中NNN搜索在由每个内部列表中的第0个元素定义的超矩形内的邻域加上或减去一些小的移位

2 个答案:

答案 0 :(得分:2)

我使用numpy.argsort()numpy.searchsorted()numpy.argmin()进行搜索。

%pylab inline
import numpy as np
np.random.seed(0)
A = np.random.rand(5, 2)
B = np.random.rand(100, 2)
xaxis_range = 0.02
order = np.argsort(B[:, 0])
bx = B[order, 0]
sidx = np.searchsorted(bx, A[:, 0] - xaxis_range, side="right")
eidx = np.searchsorted(bx, A[:, 0] + xaxis_range, side="left")
result = []
for s, e, ay in zip(sidx, eidx, A[:, 1]):
    section = order[s:e]
    by = B[section, 1]
    idx = np.argmin(np.abs(ay-by))
    result.append(B[section[idx]])
result = np.array(result)

我将结果绘制如下:

plot(A[:, 0], A[:, 1], "o")
plot(B[:, 0], B[:, 1], ".")
plot(result[:, 0], result[:, 1], "x")

输出:

enter image description here

答案 1 :(得分:0)

我对您的问题的理解是,您试图在另一组点中找到每个A[n][1]最接近的元素(B[i][1]仅限于A[n][0]位于+内的点{/ 1} - B[i][0])0.5。

我不熟悉numpy或scipy,我确信使用他们的算法有更好的方法。

话虽如此,这是我在O(a*b*log(a*b))时间的天真实施。

def main(a,b):
    for a_bound,a_val in a:
        dist_to_valid_b_points = {abs(a_val-b_val):(b_bound,b_val) for b_bound,b_val in b if are_within_bounds(a_bound,b_bound)}
        print get_closest_point((a_bound, a_val),dist_to_valid_b_points)

def are_within_bounds(a_bound, b_bound):
    return abs(b_bound-a_bound) < 0.5

def get_closest_point(a_point, point_dict):
    return (a_point, None if not point_dict else point_dict[min(point_dict, key=point_dict.get)])

main([[50,140],[51,180],[54,500]],[[50.1, 170], [51,200],[55,510]])产生以下输出:

((50, 140), (50.1, 170))
((51, 180), (51, 200))
((54, 500), None)