k-d树实现

时间:2012-03-12 19:34:53

标签: c++ stl vector tree

我有以下代码,正在尝试实现并分析它:

#include <iostream.h>
    #include "vector.h"

    /**
     * Quick illustration of a k-d tree (originally two-dimensional).
     * No abstraction here.
     *
     * All points stored in one tree must have the same number of
     * coordinates k >= 1. The root splits on coordinate 0, the next
     * level on coordinate 1, and so on, wrapping around after
     * coordinate k - 1; for k == 2 this is the classic 2-d tree that
     * alternates between x[ 0 ] and x[ 1 ].
     *
     * The tree owns its nodes: the destructor frees them and copying
     * a tree makes a deep copy.
     */
    template <class Comparable>
    class KdTree
    {
      public:
        KdTree( ) : root( NULL ) { }

        /** Builds a deep copy of rhs (no nodes are shared). */
        KdTree( const KdTree & rhs ) : root( clone( rhs.root ) ) { }

        /** Frees every node owned by this tree. */
        ~KdTree( )
        {
            makeEmpty( root );
        }

        /** Replaces this tree's contents with a deep copy of rhs. */
        KdTree & operator=( const KdTree & rhs )
        {
            if( this != &rhs )
            {
                // Clone before freeing so that a failed allocation
                // leaves *this unchanged.
                KdNode *fresh = clone( rhs.root );
                makeEmpty( root );
                root = fresh;
            }
            return *this;
        }

        /**
         * Insert point x (duplicates are allowed and go to the right).
         * Precondition: x has the same number of coordinates as the
         * points already stored. An empty x is ignored, since it has
         * no coordinate to split on.
         */
        void insert( const vector<Comparable> & x )
        {
            if( x.size( ) == 0 )
                return;
            insert( x, root, 0 );
        }

        /**
         * Print, one per line as "(x0,x1,...)", every stored point x with
         * low[ d ] <= x[ d ] <= high[ d ] for every coordinate d.
         * For 2-d points this is
         * low[ 0 ] <= x[ 0 ] <= high[ 0 ] and
         * low[ 1 ] <= x[ 1 ] <= high[ 1 ].
         * Precondition: low and high have at least as many entries as
         * the stored points have coordinates.
         */
        void printRange( const vector<Comparable> & low,
                         const vector<Comparable> & high ) const
        {
            printRange( low, high, root, 0 );
        }

      private:
        struct KdNode
        {
            vector<Comparable> data;
            KdNode            *left;   // points with a smaller split coordinate
            KdNode            *right;  // points with a greater-or-equal split coordinate

            KdNode( const vector<Comparable> & item )
              : data( item ), left( NULL ), right( NULL ) { }
        };

        KdNode *root;

        /** Split coordinate used one level below `level` for points like p. */
        static int nextLevel( int level, const vector<Comparable> & p )
        {
            return ( level + 1 ) % static_cast<int>( p.size( ) );
        }

        void insert( const vector<Comparable> & x, KdNode * & t, int level )
        {
            if( t == NULL )
                t = new KdNode( x );
            else if( x[ level ] < t->data[ level ] )
                insert( x, t->left, nextLevel( level, x ) );
            else
                insert( x, t->right, nextLevel( level, x ) );
        }

        /** True if low[ d ] <= p[ d ] <= high[ d ] for every coordinate d of p. */
        static bool inRange( const vector<Comparable> & low,
                             const vector<Comparable> & high,
                             const vector<Comparable> & p )
        {
            for( int d = 0; d < static_cast<int>( p.size( ) ); d++ )
                if( !( low[ d ] <= p[ d ] && high[ d ] >= p[ d ] ) )
                    return false;
            return true;
        }

        /** Writes p to cout as "(p0,p1,...)" followed by endl. */
        static void printPoint( const vector<Comparable> & p )
        {
            cout << "(";
            for( int d = 0; d < static_cast<int>( p.size( ) ); d++ )
            {
                if( d > 0 )
                    cout << ",";
                cout << p[ d ];
            }
            cout << ")" << endl;
        }

        void printRange( const vector<Comparable> & low,
                         const vector<Comparable> & high,
                         KdNode *t, int level ) const
        {
            if( t == NULL )
                return;

            if( inRange( low, high, t->data ) )
                printPoint( t->data );

            // insert() sends smaller split coordinates left and the rest
            // right, so each subtree is visited only if the query range
            // can reach it on the current split coordinate.
            if( low[ level ] <= t->data[ level ] )
                printRange( low, high, t->left, nextLevel( level, t->data ) );
            if( high[ level ] >= t->data[ level ] )
                printRange( low, high, t->right, nextLevel( level, t->data ) );
        }

        /** Deletes the subtree rooted at t and resets t to NULL. */
        static void makeEmpty( KdNode * & t )
        {
            if( t != NULL )
            {
                makeEmpty( t->left );
                makeEmpty( t->right );
                delete t;
                t = NULL;
            }
        }

        /** Returns a newly allocated deep copy of the subtree rooted at t. */
        static KdNode * clone( const KdNode *t )
        {
            if( t == NULL )
                return NULL;
            KdNode *node = new KdNode( t->data );
            node->left = clone( t->left );
            node->right = clone( t->right );
            return node;
        }
    };

        // Test program
        int main( )
        {
            KdTree<int> t;

            cout << "Starting program" << endl;
            for( int i = 300; i < 370; i++ )
            {
                vector<int> it( 2 );
                it[ 0 ] =  i;
                it[ 1 ] =  2500 - i;
                t.insert( it );
            }

            vector<int> low( 2 ), high( 2 );
            low[ 0 ] = 70;
            low[ 1 ] = 2186;
            high[ 0 ] = 1200;
            high[ 1 ] = 2200;

            t.printRange( low, high );

            return 0;
        }

问题在于,这里使用的 vector 类来自自定义头文件 "vector.h",很难从其源代码中弄明白它的用法,所以我想改用 C++ STL 中现有的 std::vector,但不知道该怎么做。请帮帮我,例如如何在 insert 例程中使用 vector?等等,谢谢。

1 个答案:

答案 0(得分:3)

你的代码已经与 STL 兼容了:我只是修改了头文件(#include),另外主要为了提高可读性,引入了一个 typedef:

#include <iostream>
#include <vector>
using namespace std;

/**
* Quick illustration of a k-d tree (originally two-dimensional).
* No abstraction here.
*
* All points stored in one tree must have the same number of
* coordinates k >= 1. The root splits on coordinate 0, the next
* level on coordinate 1, and so on, wrapping around after
* coordinate k - 1; for k == 2 this is the classic 2-d tree that
* alternates between x[ 0 ] and x[ 1 ].
*
* The tree owns its nodes: the destructor frees them and copying
* a tree makes a deep copy.
*/
template <class Comparable>
class KdTree
{
public:
    typedef vector<Comparable> tVec;

    KdTree( ) : root( NULL ) { }

    /** Builds a deep copy of rhs (no nodes are shared). */
    KdTree( const KdTree & rhs ) : root( clone( rhs.root ) ) { }

    /** Frees every node owned by this tree. */
    ~KdTree( )
    {
        makeEmpty( root );
    }

    /** Replaces this tree's contents with a deep copy of rhs. */
    KdTree & operator=( const KdTree & rhs )
    {
        if( this != &rhs )
        {
            // Clone before freeing so that a failed allocation
            // leaves *this unchanged.
            KdNode *fresh = clone( rhs.root );
            makeEmpty( root );
            root = fresh;
        }
        return *this;
    }

    /**
    * Insert point x (duplicates are allowed and go to the right).
    * Precondition: x has the same number of coordinates as the
    * points already stored. An empty x is ignored, since it has
    * no coordinate to split on.
    */
    void insert( const tVec & x )
    {
        if( x.size( ) == 0 )
            return;
        insert( x, root, 0 );
    }

    /**
    * Print, one per line as "(x0,x1,...)", every stored point x with
    * low[ d ] <= x[ d ] <= high[ d ] for every coordinate d.
    * For 2-d points this is
    * low[ 0 ] <= x[ 0 ] <= high[ 0 ] and
    * low[ 1 ] <= x[ 1 ] <= high[ 1 ].
    * Precondition: low and high have at least as many entries as
    * the stored points have coordinates.
    */
    void printRange( const tVec & low,
                    const tVec & high ) const
    {
        printRange( low, high, root, 0 );
    }

private:
    struct KdNode
    {
        tVec data;
        KdNode            *left;   // points with a smaller split coordinate
        KdNode            *right;  // points with a greater-or-equal split coordinate

        KdNode( const tVec & item )
            : data( item ), left( NULL ), right( NULL ) { }
    };

    KdNode *root;

    /** Split coordinate used one level below `level` for points like p. */
    static int nextLevel( int level, const tVec & p )
    {
        return ( level + 1 ) % static_cast<int>( p.size( ) );
    }

    void insert( const tVec & x, KdNode * & t, int level )
    {
        if( t == NULL )
            t = new KdNode( x );
        else if( x[ level ] < t->data[ level ] )
            insert( x, t->left, nextLevel( level, x ) );
        else
            insert( x, t->right, nextLevel( level, x ) );
    }

    /** True if low[ d ] <= p[ d ] <= high[ d ] for every coordinate d of p. */
    static bool inRange( const tVec & low, const tVec & high, const tVec & p )
    {
        for( int d = 0; d < static_cast<int>( p.size( ) ); d++ )
            if( !( low[ d ] <= p[ d ] && high[ d ] >= p[ d ] ) )
                return false;
        return true;
    }

    /** Writes p to cout as "(p0,p1,...)" followed by endl. */
    static void printPoint( const tVec & p )
    {
        cout << "(";
        for( int d = 0; d < static_cast<int>( p.size( ) ); d++ )
        {
            if( d > 0 )
                cout << ",";
            cout << p[ d ];
        }
        cout << ")" << endl;
    }

    void printRange( const tVec & low,
                    const tVec & high,
                    KdNode *t, int level ) const
    {
        if( t == NULL )
            return;

        if( inRange( low, high, t->data ) )
            printPoint( t->data );

        // insert() sends smaller split coordinates left and the rest
        // right, so each subtree is visited only if the query range
        // can reach it on the current split coordinate.
        if( low[ level ] <= t->data[ level ] )
            printRange( low, high, t->left, nextLevel( level, t->data ) );
        if( high[ level ] >= t->data[ level ] )
            printRange( low, high, t->right, nextLevel( level, t->data ) );
    }

    /** Deletes the subtree rooted at t and resets t to NULL. */
    static void makeEmpty( KdNode * & t )
    {
        if( t != NULL )
        {
            makeEmpty( t->left );
            makeEmpty( t->right );
            delete t;
            t = NULL;
        }
    }

    /** Returns a newly allocated deep copy of the subtree rooted at t. */
    static KdNode * clone( const KdNode *t )
    {
        if( t == NULL )
            return NULL;
        KdNode *node = new KdNode( t->data );
        node->left = clone( t->left );
        node->right = clone( t->right );
        return node;
    }
};

// Test program
int main_kdtree(int, char **)
{
    typedef KdTree<int> tTree;
    tTree t;

    cout << "Starting program" << endl;
    for( int i = 300; i < 370; i++ )
    {
        tTree::tVec it( 2 );
        it[ 0 ] =  i;
        it[ 1 ] =  2500 - i;
        t.insert( it );
    }

    tTree::tVec low( 2 ), high( 2 );
    low[ 0 ] = 70;
    low[ 1 ] = 2186;
    high[ 0 ] = 1200;
    high[ 1 ] = 2200;

    t.printRange( low, high );

    return 0;
}

输出:

Starting program
(300,2200)
(301,2199)
....
(313,2187)
(314,2186)