从kdtree

时间:2016-08-11 10:16:01

标签: c nearest-neighbor kdtree

我正在编写一种算法,要求我搜索点的最近邻居。我从这篇文章(Using Google's C KD Tree Library)找到了kdtree库,但它没有从树中删除单个节点的功能。所以我开始实现自己的使用 www(dot)geeksforgeeks.org/k-dimensional-tree-set-3-delete / 作为模板。这一切都贯穿但不幸的是,有时节点会重复。 我的测试用例如下:

#include <stdio.h>
#include <assert.h>    
#include <stdlib.h>
#include <math.h>
#include <errno.h>
#include <string.h>
#include <stdarg.h>
#include "kdtree.h"

/* (hopefully) platform independent directory creation */
#if defined(_WIN32) || defined(WIN32)   /* this should be defined under windows, regardless of 64 or 32 bit*/
#include <direct.h>
#include <sys/stat.h>
#define GetWorkingDir _getcwd
#define MakeDir(str) _mkdir(str)
#else                                   /* unix based system */
#include <unistd.h>
#include <sys/stat.h>
#define GetWorkingDir getcwd
#define MakeDir(str) mkdir(str, 0777)
#endif

#ifndef MAX_PATH
#define MAX_PATH 260
#endif

void GetLogDir(char* strPath, int nBufSize)
{
    if(GetWorkingDir(strPath, nBufSize))
    {
        strncat(strPath, "/log/", 5);
        MakeDir(strPath);
    }
    else
    {
        fprintf(stderr, "Could not get working directory");
        exit(ENOENT);
    }
}

FILE* GetOpenFileHandle(const char* strFilenamePlusPath, const char* strOpenMode)
{
    if(strOpenMode == NULL)     // too bad we dont have default arguments in C :(
    {
            strOpenMode = "a+";
    }

    return(fopen(strFilenamePlusPath, strOpenMode));
}

int CloseFile(FILE* pFile)
{
    if(pFile != NULL)
    {
        fprintf(pFile, "\r\n"); // append a new line before closing!
        return(fclose(pFile));
    }

    fprintf(stderr, "Invalid file handle");
    exit(EFAULT);
}


void NodeLabelToFile(FILE* pFile, kdnode* node, const char* strName)
{
    fprintf(pFile, "%s [label=\"(%.3f, %.3f)\"] \n", strName, node->pos[0], node->pos[1]);
}

char* NodeToString(kdnode* node, int* num)
{
    char* strName = (char*) malloc(MAX_PATH);
    if(*num == 0)
    {
        sprintf(strName, "%s","root");
    }
    else
    {
        sprintf(strName, "node%d", *num);
    }
    return strName;
}


void NodesToFile(FILE* pFile, kdnode* node, const char* strParentname, int* num)
{
    if(node && pFile)
    {
        char* strLeft = NULL;
        char* strRight = NULL;

        if(node->left)
        {
            (*num)++;
            strLeft = NodeToString(node->left, num);
            NodeLabelToFile(pFile, node->left, strLeft);
            fprintf(pFile, "%s -> %s \n", strParentname, strLeft);
        }

        if(node->right)
        {
            (*num)++;
            strRight = NodeToString(node->right, num); // name of the current node
            NodeLabelToFile(pFile, node->right, strRight);
            fprintf(pFile, "%s -> %s \n", strParentname, strRight);
        }

        if(strLeft)
        {
            NodesToFile(pFile, node->left, strLeft, num);
            free(strLeft);
        }
        if(strRight)
        {
            // (*num)++;
            NodesToFile(pFile, node->right, strRight, num);
            free(strRight);
        }
    }
}


FILE* MakeOpenLogFile(const char* strFilename, const char* strOpenMode)
{
    if(strOpenMode == NULL)
    {
        strOpenMode = "a+";
    }

    char* strFilenamePlusPath = (char*) malloc(MAX_PATH);
    GetLogDir(strFilenamePlusPath, MAX_PATH);
    strncat(strFilenamePlusPath, strFilename, strlen(strFilename));
    FILE* pFile = GetOpenFileHandle(strFilenamePlusPath, strOpenMode);
    free(strFilenamePlusPath);
    return(pFile);
}

void KDTreeToDotFile(kdtree* Tree, const char* strFilename)
{
    if(Tree)
    {
        FILE* pFile = MakeOpenLogFile(strFilename, "w");

        fprintf(pFile, "%s", "digraph d { \n"); // print opening statement for the graph in dot language

        // traverse the tree and print the nodes
        int* num = (int*) malloc(sizeof(int));  // make this a unique location to make sure numbers can't occur twice

        *num = 0;
        char* strRoot = NodeToString(Tree->root, num);
        NodeLabelToFile(pFile, Tree->root, strRoot);
        NodesToFile(pFile, Tree->root, "root", num);

        if(strRoot)
        {
            free(strRoot);
        }
        free(num);
        fprintf(pFile,"%s", "}");            // close the digraph environment
        CloseFile(pFile);
    }
}

int main(int argc, const char * argv[])
{
    int numel = 20;
    int toRemove = 19;
    double dMax = 3000;
    int nNumDim = 2;

    printf("init rng");
    srand(1234); // seed the rng // srand((unsigned) time(&t));

    printf("creating kdtree");
    kdtree* TreeRoot = kd_create(nNumDim);  // construct the kd tree for the nearest neighbor search
    kd_data_destructor(TreeRoot, free); // set free as data destructor

    double* pos = (double*) malloc(nNumDim * numel * sizeof(double));
    int retval;

    for (int ii = 0; ii < numel; ii++)
    {
        pos[nNumDim * ii] = floor((double)rand()/(double)(RAND_MAX/dMax));
        pos[nNumDim * ii + 1] = floor((double)rand()/(double)(RAND_MAX/dMax));
        int* randint = (int*) malloc(sizeof(int));
        *randint = rand();
        retval = kd_insert2(TreeRoot,
                            pos[nNumDim * ii],
                            pos[nNumDim * ii + 1],
                            randint, sizeof(int));
        assert(retval == 0);
    }

    KDTreeToDotFile(TreeRoot, "original.dot");
    double* dRemovePos = (double*) malloc(sizeof(double)*nNumDim);
    for (int ii = 0; ii < toRemove; ii++)
    {
        dRemovePos[0] = pos[2*ii];
        dRemovePos[1] = pos[2*ii + 1];
        kd_remove(TreeRoot, dRemovePos);
    }
    KDTreeToDotFile(TreeRoot, "removed.dot");

    kd_free(TreeRoot);                  // free kdtree
    return 0;
}

并删除节点的功能如下所示: (我不认为是否代码太多,所以我只会将我的更改发布到kd库。如果我应该添加其余代码,不幸的是超过1000行,请在评论中告诉我。 )

int kd_remove(kdtree* tree, const double* pos)
{
    printf("removing node %.3f, %.3f \n", pos[0], pos[1]);
    if(tree->root != NULL)
    {
        assert(tree->dim != 0); // prevent division by 0 (error code 136)
        assert(pos != NULL);    // make sure a valid position is passed
        tree->root = remove_rec(tree->root, pos, tree->dim, tree->destr, 0);
    }
    return(0);
}

kdnode* remove_rec(kdnode* node, const double* pos, int dim, void (*destr)(void*), int depth)
{
    if(node == NULL)
    {
        return(NULL);
    }

    int curdim = depth % dim;

    if(same_pos(node->pos, pos, dim))
    {
        // we found the droid we're looking for
        if(node->right)
        {
            // find the minimum in the right subtree
            kdnode* node_min = find_min(node->right, curdim, dim);
            if(node_min)
            {
                copy_node_data(node_min, node, dim);
                node->right = remove_rec(node->right, node_min->pos, dim, destr, depth + 1);
            }
        }
        else if(node->left)
        {
            // find the minimum in the left subtree
            kdnode* node_min = find_min(node->left, curdim, dim);
            if(node_min)
            {
                copy_node_data(node_min, node, dim);
                node->left = remove_rec(node->left, node_min->pos, dim, destr, depth + 1);
            }
        }
        else
        {
            // no subtrees -> delete the found node
            clear_rec(node, destr);
            return(NULL);
        }
        return node;    // return the newly filled node to the recursion step one "above"
    }
    else
    {
        // points are not the same, look further
        if(pos[curdim] < node->pos[curdim])
        {
            // position we're looking for is smaller -> go left
            node->left = remove_rec(node->left, pos, dim, destr, depth + 1);
        }
        else
        {
            // go right, position we're looking for is greater
            node->right = remove_rec(node->right, pos, dim, destr, depth + 1);
        }
        return node;
    }
}

void copy_node_data(const kdnode* src, kdnode* dst, int dim)
{
    if(src && dst)
    {
        int nNumBytes = dim * sizeof(double);
        memcpy(dst->pos, src->pos, nNumBytes);

        if(dst->data != NULL)
        {
            free(dst->data);
            dst->data = malloc(src->databytes);
        }

        memcpy(dst->data, src->data, src->databytes);
        dst->databytes = src->databytes;
    }
}

int same_pos(const double* pos1, const double* pos2, int dim)
{
    for (int i = 0; i < dim; ++i)
    {
        if(pos1[i] != pos2[i])
        {
            return 0; // false
        }
    }
    return 1;   // true
}

kdnode* find_min(kdnode* node, int dir, int numdim)
{
    return find_min_rec(node, dir, 0, numdim);
}

kdnode* find_min_rec(kdnode* node, int dir, int depth, int numdim)
{
    if(!node)
    { 
        return NULL;
    }

    if(node->left == NULL && node->right == NULL)
    {
        return node; // is leaf node 
    }
    int curdim = depth % numdim;
    if(curdim == numdim)
    {
        if(node->left == NULL)
        {
            // no smaller node in tree
            return node;
        }
        else 
        {
            // left subtree is populated -> we need to go deeper
            return find_min_rec(node->left, node->dir, depth + 1, numdim);;
        }
    } 

    // we have to search both subtrees and find the smallest value compared to the current node
    return min_node(node,   find_min_rec(node->left, node->dir, depth + 1, numdim),
                            find_min_rec(node->right, node->dir, depth + 1, numdim), node->dir);
}

kdnode* min_node(kdnode* a, kdnode* left, kdnode* right, int dir)
{
    if(a == NULL)
    {
        // node a is the only one that can't be NULL!
        fprintf(stderr, "Error: invalid node passed! \n");
        exit(EFAULT);
    } 

    kdnode* result = a;

    if(left != NULL)
    {
        if(left->pos[dir] < result->pos[dir])
        {   
            result = left;
        }
    }

    if(right != NULL)
    {
        if(right->pos[dir] < result->pos[dir])
        {
            result = right;
        }
    }  

    return result;
}

original.dot看起来像这样,removed.dot就是这样。 我从昨天起就一直在调试这个,我觉得这里有一些非常明显的东西,我在这里失踪了...... 提前感谢愿意帮助的人:)

2 个答案:

答案 0 :(得分:1)

您正在创建40元素

int numel = 20;
int nNumDim = 2;

double* pos = (double*) malloc(nNumDim * numel * sizeof(double)); // Don't cast

但仅删除38

int toRemove = 19;

for (int ii = 0; ii < toRemove; ii++)
{
    dRemovePos[0] = pos[nNumDim * ii];
    dRemovePos[1] = pos[nNumDim * ii + 1];
    kd_remove(TreeRoot, dRemovePos);
}

在上一次迭代中:

pos[nNumDim * ii]; = pos[2 * 18]; = pos[36];

pos[nNumDim * ii + 1]; = pos[2 * 18 + 1]; = pos[37];

pos[38]pos[39]仍在那里。

更改为int toRemove = 20;

由于平面数组,您的代码被混淆了,为什么不声明类似

的类型
struct data {
    double el1;
    double el2;
};

typedef double data[2];

然后

data *value = malloc(numel * sizeof(*value));

答案 1 :(得分:0)

所以,我知道这可能不会被任何人阅读,但我在没有触及代码一段时间后发现了这个错误,为了完整性,这里是如何:

find_min()函数中,我使用depth = 0开始递归。 这可能导致拆分维度混乱,因此无法访问所有节点。 我修改了函数以depth作为参数并传递递归深度remove_rec(),如下所示:

kdnode* node_min = find_min(node->right, curdim, dim, depth + 1);

kdnode* node_min = find_min(node->left, curdim, dim, depth + 1);

分别