How to restrict a KDTree query to a subset of nodes?

Date: 2014-10-28 13:41:25

Tags: python data-structures computational-geometry nearest-neighbor minimum-spanning-tree

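TL;DR: I need a way to find the "nearest foreign neighbor" using a KDTree or some other spatial data structure, i.e. the nearest neighbor within a subset of the tree's points.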
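To pin down what "nearest foreign neighbor" means, here is a minimal brute-force reference. It assumes the points are rows of a NumPy array and that `subset` is a collection of row indices; the name `foreign_nn_bruteforce` is hypothetical. It costs O(|subset|) per query, so it is mainly useful for checking a faster tree-based search.

```python
import numpy as np


def foreign_nn_bruteforce(points, subset, b):
    """Return the index in `subset` of the point nearest to b (None if subset is empty)."""
    idx = np.fromiter(subset, dtype=np.intp)
    if idx.size == 0:
        return None
    pts = np.asarray(points, dtype=float)[idx]
    # squared Euclidean distances are enough for the argmin
    d2 = np.sum((pts - np.asarray(b, dtype=float)) ** 2, axis=1)
    return int(idx[np.argmin(d2)])
```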

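I built an MST algorithm that uses a KDTree to find nearest neighbors. Eventually, however, it has to look beyond the nearest neighbor to the "nearest foreign neighbor" in order to connect distant nodes. My first approach was simply to increase the k-NN parameter iteratively until the query returned a node in the subset. I cache k for each query point between calls, so each call widens the search from where the previous one stopped instead of re-checking k < k_cache.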

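import numpy as np

k_cache = {}  # last k used for each query point, keyed by str(b)

def FNNd(kdtree, A, b):
    """
    kdtree -> nodes in subnet -> coord of b -> index of a
    returns nearest foreign neighbor a∈A of b
    """
    a = None
    b = cartesian_projection(b)
    k = k_cache.get(str(b), 2)

    while a not in A:
        if k > kdtree.n:
            return None  # every point has been checked; no member of A found
        # scipy kdtree where query -> [dist], [idx]
        _, nn = kdtree.query(b, k=k)
        # the k-th nearest neighbor (works for a 1-D point or a (1, d) row)
        a = np.ravel(nn)[-1]
        k += 1

    k_cache[str(b)] = k-1
    # return NN a ∈ A of b
    return a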
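The same expanding-k idea can be written self-contained with SciPy. This is a hypothetical variant (`foreign_nn_expanding_k` is not from the post): it assumes plain Cartesian points with no `cartesian_projection` step, checks every returned neighbor instead of only the k-th one, and doubles k each round, so at most about log2(n) queries are issued.

```python
import numpy as np
from scipy.spatial import cKDTree


def foreign_nn_expanding_k(tree, subset, b, k=2):
    """Return (index, distance) of the point in `subset` nearest to b,
    or (None, inf) if `subset` contains none of the tree's points."""
    n = tree.n
    while True:
        k = min(k, n)  # never ask for more neighbors than there are points
        dist, idx = tree.query(b, k=k)
        # neighbors come back sorted by distance; atleast_1d covers k == 1
        for d, i in zip(np.atleast_1d(dist), np.atleast_1d(idx)):
            if i in subset:
                return int(i), float(d)
        if k == n:
            return None, float("inf")
        k *= 2  # widen the search and try again


# usage: the nearest of points 3, 17 and 42 to (0.5, 0.5)
pts = np.random.default_rng(0).random((1000, 2))
tree = cKDTree(pts)
print(foreign_nn_expanding_k(tree, {3, 17, 42}, [0.5, 0.5]))
```

Because `query` returns neighbors in order of increasing distance, the first returned index that belongs to `subset` is the nearest subset member. The cost is still driven by how many non-subset points lie closer than the answer.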

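However, this is very "hacky" and inefficient, so I thought I could implement my own KDTree that stops traversing whenever doing so would lead into a subtree containing none of the restricted subset. The nearest neighbor in the subset must then lie in either the left or the right branch. After many attempts, I can't seem to get this to work. Is there a flaw in my logic? Is there a better approach, or a better data structure?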
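The pruning idea in the paragraph above can be sketched as follows. This is a minimal sketch, not the poster's class: `SubsetKDTree` and `query_subset` are hypothetical names, and it assumes Euclidean distance and that `subset` is a collection of row indices. Each subtree is stored as a contiguous slice of a permutation array, so a prefix sum over the subset mask tells in O(1) whether a subtree contains any subset member; such subtrees are skipped, and the usual splitting-plane test decides whether the far branch must be visited.

```python
import numpy as np


class SubsetKDTree:
    """Implicit KD-tree: every subtree is a contiguous slice of self.order,
    and the median element of each slice is that subtree's root."""

    def __init__(self, points):
        self.points = np.asarray(points, dtype=float)
        self.n, self.dim = self.points.shape
        self.order = np.arange(self.n)
        self._build(0, self.n, 0)

    def _build(self, lo, hi, depth):
        if hi - lo <= 1:
            return
        axis = depth % self.dim
        seg = self.order[lo:hi]
        # sort the slice along the split axis; the median becomes the node
        self.order[lo:hi] = seg[np.argsort(self.points[seg, axis], kind="stable")]
        mid = (lo + hi) // 2
        self._build(lo, mid, depth + 1)
        self._build(mid + 1, hi, depth + 1)

    def query_subset(self, q, subset):
        """Return (index, distance) of the nearest point whose index is in
        `subset`, or (None, inf) if the subset is empty."""
        q = np.asarray(q, dtype=float)
        mask = np.zeros(self.n, dtype=bool)
        mask[np.array(list(subset), dtype=np.intp)] = True
        # prefix[j] = number of subset members among order[:j], so the slice
        # [lo, hi) contains a subset member iff prefix[hi] > prefix[lo]
        prefix = np.concatenate(([0], np.cumsum(mask[self.order])))
        best = [None, np.inf]  # [index, squared distance]

        def search(lo, hi, depth):
            if prefix[hi] == prefix[lo]:
                return  # no subset member below (also covers empty slices)
            mid = (lo + hi) // 2
            idx = self.order[mid]
            if mask[idx]:
                d2 = float(np.sum((self.points[idx] - q) ** 2))
                if d2 < best[1]:
                    best[0], best[1] = int(idx), d2
            axis = depth % self.dim
            diff = q[axis] - self.points[idx, axis]
            near, far = ((lo, mid), (mid + 1, hi)) if diff < 0 else ((mid + 1, hi), (lo, mid))
            search(*near, depth + 1)
            # the far side can only help if the splitting plane is closer than best
            if diff ** 2 < best[1]:
                search(*far, depth + 1)

        search(0, self.n, 0)
        return best[0], float(np.sqrt(best[1]))
```

Building the prefix sum costs O(n) per call; if the same subset is queried many times, it could be computed once and reused.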

Here's my KDTree:

import numpy as np


def distance(p, q):
    # Euclidean distance (the post does not show its own distance helper)
    return np.linalg.norm(np.asarray(p, dtype=float) - np.asarray(q, dtype=float))


class KDTree(object):
    def __init__(self, data, depth=0, make_idx=True):
        self.n, self.k = data.shape

        if make_idx:
            # index the data: append each row's original index as a last column
            data = np.column_stack((data, np.arange(self.n)))
        else:
            # subtract the indexed dimension in later calls
            self.k -= 1

        self.build(data, depth)

    def build(self, data, depth):

        if data.size > 0:
            # get the axis to pivot on
            self.axis = depth % self.k
            # sort the data
            s_data = data[np.argsort(data[:, self.axis])]
            # find the pivot point
            mid = len(s_data) // 2
            point = s_data[mid]

            # point coord
            self.point = point[:-1]
            # point index
            self.idx = int(point[-1])

            # all nodes below this node: every row except the pivot row
            # (a mask like np.all(coords != pivot) would also drop rows that
            # merely share one coordinate with the pivot)
            self.children = np.delete(s_data, mid, axis=0)
            # branches
            self.left  = KDTree(s_data[:mid], depth+1, False)
            self.right = KDTree(s_data[mid + 1:], depth+1, False)
        else:
            # empty node
            self.axis = 0
            self.point = self.idx = self.left = self.right = None
            self.children = np.array([])

    def query(self, point, best=None):

        if self.point is None:
            return best

        if best is None:
            best = (self.idx, self.point)

        # check if current node is closer than best
        if distance(self.point, point) < distance(best[1], point):
            best = (self.idx, self.point)

        # continue traversing the tree
        best = self.near_tree(point).query(point, best)

        # traverse the away branch if the orthogonal distance is less than best
        if self.orthongonal_dist(point) < distance(best[1], point):
            best = self.away_tree(point).query(point, best)
        return best

    def orthongonal_dist(self, point):
        # distance from point to the splitting plane: project point onto the
        # plane and measure to the projection (not to the pivot itself)
        orth_point = np.array(point, dtype=float)
        orth_point[self.axis] = self.point[self.axis]
        return distance(point, orth_point)

    def near_tree(self, point):
        if point[self.axis] < self.point[self.axis]:
            return self.left
        return self.right

    def away_tree(self, point):
        if self.near_tree(point) is self.left:
            return self.right
        return self.left

[EDIT] Updated attempt, but it still isn't guaranteed to return a correct result:

def query_subset(self, point, subset, best=None):

    # if point in subset, update best
    if self.idx in subset:
        # if closer than current best, or best is none update
        if best is None or distance(self.point, point) < distance(best[1], point):
            best = (self.idx, self.point)

    # Dead end backtrack up the tree
    if self.point is None:
        return best

    near = self.near_tree(point)
    far = self.away_tree(point)

    # what nodes are in the near branch
    if near.children.size > 1:
        near_set = set(np.append(near.children[:, -1], near.idx))
    else: near_set = {near.idx}

    # check the near branch, if its nodes intersect with the queried subset
    # otherwise move to the away branch
    if any(x in near_set for x in subset):
        best = near.query_subset(point, subset, best)
    else:
        best = far.query_subset(point, subset, best)

    # validate best, by ensuring closer point doesn't exist just beyond partition
    if best is not None:
        if self.orthongonal_dist(point) < distance(best[1], point):
            best = far.query_subset(point, subset, best)    

    return best 

0 answers