TL;DR I need a way to use a KDTree (or some other spatial data structure) to find the "foreign nearest neighbor", i.e. the nearest neighbor within a subset of the tree.
I built an MST algorithm that uses a KDTree to find nearest neighbors. Eventually, however, it needs to look beyond the plain nearest neighbor to the "nearest foreign neighbor" in order to connect distant nodes. My first approach simply increases the k-nn parameter iteratively until the query returns a node in the subset. I cache k on each call of the function so that the search breadth is expanded every time and the previously searched k < k_cache is not repeated.
def FNNd(kdtree, A, b):
    """
    kdtree -> nodes in subnet -> coord of b -> index of a
    returns nearest foreign neighbor a∈A of b
    """
    a = None
    b = cartesian_projection(b)  # helper defined elsewhere in my code
    # k_cache (defined elsewhere) remembers how far we already searched for this point
    k = k_cache[str(b)] if str(b) in k_cache else 2
    while a not in A:
        # scipy kdtree where query -> [dist], [idx]
        _, nn = kdtree.query(b, k=k)
        a = nn[-1][k-1]
        k += 1
    k_cache[str(b)] = k-1
    # return NN a ∈ A of b
    return a
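For context, this is roughly how it gets called; the toy setup below stubs out k_cache and cartesian_projection, which live elsewhere in my real code:

import numpy as np
from scipy.spatial import cKDTree

k_cache = {}                             # per-point cache of the last k searched

def cartesian_projection(b):
    # stub: my real helper projects the node to Cartesian coordinates;
    # returning a 2-D array keeps scipy's query output in the shape FNNd indexes
    return np.atleast_2d(b)

points = np.random.rand(100, 2)          # toy coordinates
kdtree = cKDTree(points)                 # scipy KD-tree over every node
A = {3, 17, 42}                          # indices already in the subset/subnet

a = FNNd(kdtree, A, points[7])           # nearest neighbor of node 7 within A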
However, this is very 'hacky' and inefficient, so I figured I could implement a KDTree myself that stops traversing whenever descending would lead to a subtree containing none of the restricted subset. The nearest neighbor within the subset must then be in either the left or the right branch. After many attempts I can't seem to get this implemented. Is there a flaw in my logic? A better approach? A better data structure?
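Conceptually, the pruning I have in mind looks something like this. This is only a rough, self-contained sketch: Node and its subtree_idx set are hypothetical and not part of my actual tree below.

import numpy as np

class Node:
    # hypothetical node: remembers which point indices live in its subtree
    def __init__(self, idx, point, left=None, right=None):
        self.idx, self.point = idx, np.asarray(point)
        self.left, self.right = left, right
        self.subtree_idx = {idx}
        for child in (left, right):
            if child is not None:
                self.subtree_idx |= child.subtree_idx

def nn_in_subset(node, point, subset, best=None):
    # skip any branch that contains nothing from the restricted subset
    if node is None or not (node.subtree_idx & subset):
        return best
    if node.idx in subset:
        d = np.linalg.norm(node.point - point)
        if best is None or d < best[0]:
            best = (d, node.idx)
    best = nn_in_subset(node.left, point, subset, best)
    best = nn_in_subset(node.right, point, subset, best)
    return best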
Here's my KDTree:
import numpy as np

def distance(p, q):
    # assumed helper: plain Euclidean distance (my real helper lives elsewhere)
    return np.linalg.norm(np.asarray(p) - np.asarray(q))

class KDTree(object):

    def __init__(self, data, depth=0, make_idx=True):
        self.n, self.k = data.shape
        if make_idx:
            # index the data
            data = np.column_stack((data, np.arange(self.n)))
        else:
            # subtract the indexed dimension in later calls
            self.k -= 1
        self.build(data, depth)

    def build(self, data, depth):
        if data.size > 0:
            # get the axis to pivot on
            self.axis = depth % self.k
            # sort the data
            s_data = data[np.argsort(data[:, self.axis])]
            # find the pivot point
            point = s_data[len(s_data) // 2]
            # point coord
            self.point = point[:-1]
            # point index
            self.idx = int(point[-1])
            # all nodes below this node
            self.children = s_data[np.all(s_data[:, :-1] != self.point, axis=1)]
            # branches
            self.left = KDTree(s_data[: len(s_data) // 2], depth+1, False)
            self.right = KDTree(s_data[len(s_data) // 2 + 1:], depth+1, False)
        else:
            # empty node
            self.axis = 0
            self.point = self.idx = self.left = self.right = None
            self.children = np.array([])

    def query(self, point, best=None):
        if self.point is None:
            return best
        if best is None:
            best = (self.idx, self.point)
        # check if current node is closer than best
        if distance(self.point, point) < distance(best[1], point):
            best = (self.idx, self.point)
        # continue traversing the tree
        best = self.near_tree(point).query(point, best)
        # traverse the away branch if the orthogonal distance is less than best
        if self.orthongonal_dist(point) < distance(best[1], point):
            best = self.away_tree(point).query(point, best)
        return best

    def orthongonal_dist(self, point):
        # distance from the query point to this node's splitting plane
        orth_point = np.copy(point)
        orth_point[self.axis] = self.point[self.axis]
        return distance(point, orth_point)

    def near_tree(self, point):
        if point[self.axis] < self.point[self.axis]:
            return self.left
        return self.right

    def away_tree(self, point):
        if self.near_tree(point) == self.left:
            return self.right
        return self.left
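For reference, building and querying it looks like this (toy data, illustrative only):

import numpy as np

points = np.random.rand(50, 2)                 # toy 2-D coordinates
tree = KDTree(points)                          # indexes the points internally

idx, coord = tree.query(np.array([0.5, 0.5]))  # unrestricted nearest neighbor
print(idx, coord)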
[EDIT] Updated attempt, but this doesn't guarantee a return:
def query_subset(self, point, subset, best=None):
    # if point in subset, update best
    if self.idx in subset:
        # if closer than current best, or best is none update
        if best is None or distance(self.point, point) < distance(best[1], point):
            best = (self.idx, self.point)
    # Dead end backtrack up the tree
    if self.point is None:
        return best
    near = self.near_tree(point)
    far = self.away_tree(point)
    # what nodes are in the near branch
    if near.children.size > 1:
        near_set = set(np.append(near.children[:, -1], near.idx))
    else:
        near_set = {near.idx}
    # check the near branch, if its nodes intersect with the queried subset
    # otherwise move to the away branch
    if any(x in near_set for x in subset):
        best = near.query_subset(point, subset, best)
    else:
        best = far.query_subset(point, subset, best)
    # validate best, by ensuring closer point doesn't exist just beyond partition
    if best is not None:
        if self.orthongonal_dist(point) < distance(best[1], point):
            best = far.query_subset(point, subset, best)
    return best
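A call that shows the problem (toy data; I attach the edited method to the class just for this demo, and the result sometimes comes back as None even though the subset is non-empty):

import numpy as np

KDTree.query_subset = query_subset        # attach the edited method for this demo

points = np.random.rand(50, 2)
tree = KDTree(points)

subset = {4, 9, 31}                       # restrict the search to these indices
best = tree.query_subset(np.array([0.2, 0.8]), subset)
print(best)                               # expected (idx, coord), but sometimes None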