使用KD树查找范围内的所有邻居

时间:2013-04-15 19:24:20

标签: java data-structures kdtree dbscan

我正在尝试实施KD-tree以便与DBSCAN一起使用。问题是我需要找到满足距离标准的所有点的所有邻居。问题是,当我在我的实现中使用nearestNeighbours方法时,使用原始搜索(这是所需的输出)时,我得不到相同的输出。我的实现改编自python implementation。这是我到目前为止所得到的:

//Point.java
package dbscan_gui;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;

public class Point {

    final HashSet<Point> neighbours = new HashSet<Point>();
    int[] points;
    boolean visited = false;
    public Point(int... is) {
        this.points = is;
    }
    public String toString() {
        return Arrays.toString(points);
    }

    public double squareDistance(Point p) {
        double sum = 0;
        for (int i = 0;i < points.length;i++) {
            sum += Math.pow(points[i] - p.points[i],2);
        }
        return sum;
    }
    public double distance(Point p) {
        return Math.sqrt(squareDistance(p));
    }
    public void addNeighbours(ArrayList<Point> ps) {
        neighbours.addAll(ps);
    }
    public void addNeighbour(Point p) {
        if (p != this)
            neighbours.add(p);
    }
}

//KDTree.java
package dbscan_gui;


import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.TreeSet;


public class KDTree {
    KDTreeNode root;
    PointComparator[] comps;
    public KDTree(ArrayList<Point> list) {
        int axes = list.get(0).points.length;

        comps = new PointComparator[axes];
        for (int i = 0; i < axes; i++) {
            comps[i] = new PointComparator(i);
        }
        root = new KDTreeNode(list,0);
    }
    private class PointComparator implements Comparator<Point> {
        private int axis;
        public PointComparator(int axis) {
            this.axis = axis;
        }
        @Override
        public int compare(Point p1, Point p2) {
            return p1.points[axis] - p2.points[axis];
        } 
    }

    /**
     * Adapted from https://code.google.com/p/python-kdtree/
     * Stores points in a tree, sorted by axis
     */
    public class KDTreeNode {
        KDTreeNode leftChild = null;
        KDTreeNode rightChild = null;
        Point location;

        public KDTreeNode(ArrayList<Point> list, int depth) {
            if(list.isEmpty())
                return;
            final int axis = depth % (list.get(0).points.length);

            Collections.sort(list, comps[axis] );
            int median = list.size()/2;
            location = list.get(median);
            List<Point> leftPoints = list.subList(0, median);
            List<Point> rightPoints = list.subList(median+1, list.size());
            if(!leftPoints.isEmpty())
                leftChild  = new KDTreeNode(new ArrayList<Point>(leftPoints), depth+1);
            if(!rightPoints.isEmpty())
                rightChild = new KDTreeNode(new ArrayList<Point>(rightPoints),depth+1);
        }

        /**
         * @return true if this node has no children
         */
        public boolean isLeaf() {
            return leftChild == null && rightChild == null;
        }

    }
    /**
     * Finds the nearest neighbours of a point that fall within a given distance
     * @param queryPoint the point to find the neighbours of
     * @param epsilon the distance threshold
     * @return the list of points
     */
    public ArrayList<Point> nearestNeighbours(Point queryPoint, int epsilon) {
        KDNeighbours neighbours = new KDNeighbours(queryPoint);
        nearestNeighbours_(root, queryPoint, 0, neighbours);
        return neighbours.getBest(epsilon);
    }
    /**
     * @param node
     * @param queryPoint
     * @param depth
     * @param bestNeighbours
     */
    private void nearestNeighbours_(KDTreeNode node, Point queryPoint, int depth, KDNeighbours bestNeighbours) {
        if(node == null)
            return;
        if(node.isLeaf()) {
            bestNeighbours.add(node.location);
            return;
        }
        int axis = depth % (queryPoint.points.length);
        KDTreeNode nearSubtree = node.rightChild;
        KDTreeNode farSubtree  = node.leftChild;
        if(queryPoint.points[axis] < node.location.points[axis]) {
            nearSubtree = node.leftChild;
            farSubtree = node.rightChild;
        }
        nearestNeighbours_(nearSubtree, queryPoint,  depth+1, bestNeighbours);
        if(node.location != queryPoint)
            bestNeighbours.add(node.location);       
        if(Math.pow(node.location.points[axis] - queryPoint.points[axis],2) <= bestNeighbours.largestDistance)
            nearestNeighbours_(farSubtree, queryPoint, depth+1,bestNeighbours);
        return;
    }
    /**
     * Private datastructure for holding the neighbours of a point
     */
    private class KDNeighbours {
        Point queryPoint;
        double largetsDistance = 0;
        TreeSet<Tuple> currentBest = new TreeSet<Tuple>(new Comparator<Tuple>() {

            @Override
            public int compare(Tuple o1, Tuple o2) {
                return (int) (o1.y-o2.y);
            }

        });
        KDNeighbours(Point queryPoint) {
            this.queryPoint = queryPoint;
        }
        public ArrayList<Point> getBest(int epsilon) {
            ArrayList<Point> best = new ArrayList<Point>();
            Iterator<Tuple> it = currentBest.iterator();
            while(it.hasNext()) {
                Tuple t =it.next();
                if(t.y > epsilon*epsilon)
                    break;
                else if(t.x != queryPoint)
                    best.add(t.x);
            }
            return best;
        }

        public void add(Point p) {
            currentBest.add(new Tuple(p, p.squareDistance(queryPoint)));
            largestDistance = currentBest.last().y;
        }
        private class Tuple  {
            Point x;
            double y;
            Tuple(Point x, double y) {
                this.x = x;
                this.y = y;
            }
        }
    }

    public static void main(String[] args) {
        int epsilon = 3;

        System.out.println("Epsilon: "+epsilon);
        ArrayList<Point> points = new ArrayList<Point>();
        Random r = new Random();
        for (int i = 0; i < 10; i++) {
            points.add(new Point(r.nextInt(10), r.nextInt(10)));
        }
        System.out.println("Points "+points );
        System.out.println("----------------");
        System.out.println("Neighbouring Kd");
        KDTree tree = new KDTree(points);
        for (Point p : points) {
            ArrayList<Point> neighbours = tree.nearestNeighbours(p, epsilon);
            for (Point q : neighbours) {
                q.addNeighbour(p);
            }
            p.addNeighbours(neighbours);
            p.printNeighbours();
            p.neighbours.clear();
        }
        System.out.println("------------------");
        System.out.println("Neighbouring O(n^2)");
        for (int i = 0; i < points.size(); i++) {
            for (int j = i + 1; j < points.size(); j++) {
                Point p = points.get(i), q = points.get(j);
                if (p.distance(q) <= epsilon) {
                    p.addNeighbour(q);
                    q.addNeighbour(p);
                }
            }
        }
        for (Point point : points) {
            point.printNeighbours();
        }

    }
}

当我运行它时,我得到以下输出(后一部分是模型输出):

Epsilon: 3
Points [[9, 5], [4, 7], [3, 1], [0, 0], [5, 7], [0, 1], [5, 5], [1, 2], [9, 2], [9, 9]]
----------------
Neighbouring Kd
Neighbours of [0, 0] are: [[0, 1]]
Neighbours of [0, 1] are: [[1, 2], [0, 0], [3, 1]]
Neighbours of [1, 2] are: [[0, 1], [3, 1]]
Neighbours of [3, 1] are: [[0, 1], [1, 2]]
Neighbours of [4, 7] are: [[5, 7]]
Neighbours of [5, 7] are: [[4, 7]]
Neighbours of [5, 5] are: [[4, 7], [5, 7]]
Neighbours of [9, 5] are: [[9, 2]]
Neighbours of [9, 2] are: [[9, 5]]
Neighbours of [9, 9] are: []
------------------
Neighbouring O(n^2)
Neighbours of [0, 0] are: [[0, 1], [1, 2]]
Neighbours of [0, 1] are: [[1, 2], [0, 0], [3, 1]]
Neighbours of [1, 2] are: [[0, 1], [0, 0], [3, 1]]
Neighbours of [3, 1] are: [[0, 1], [1, 2]]
Neighbours of [4, 7] are: [[5, 5], [5, 7]]
Neighbours of [5, 7] are: [[4, 7], [5, 5]]
Neighbours of [5, 5] are: [[4, 7], [5, 7]]
Neighbours of [9, 5] are: [[9, 2]]
Neighbours of [9, 2] are: [[9, 5]]
Neighbours of [9, 9] are: []

我无法弄清楚为什么邻居不一样,似乎它可以发现a-&gt; b是一个邻居,但不是b-&gt; a也是一个邻居。

1 个答案:

答案 0 :(得分:-1)

您可能希望使用包含DBSCAN的ELKI和索引结构,例如最近邻搜索的R * -tree。当参数化正确时,它真的非常快。我在trac中看到下一个版本也将有一个KD树。

从快速查看代码,我必须同意@ThomasJungblut - 你没有回溯,然后根据需要尝试其他分支,这就是为什么你错过了很多邻居。 您可能需要查看两个分支机构!