查找特定距离内的所有最近邻居

时间:2015-09-06 14:27:55

标签: python numpy nearest-neighbor

我有一个大的x和y坐标列表,存储在numpy数组中。

Coordinates = [[ 60037633 289492298]
 [ 60782468 289401668]
 [ 60057234 289419794]]
...
...

我想要的是找到特定距离内的所有最近邻居(比方说3米)并存储结果,以便我以后可以对结果进行进一步的分析。

我找到的大多数软件包都要求事先指定要查找的最近邻(NN)数量,但我只想找出设定距离内的所有最近邻。

我怎样才能实现这个目标?对于大型数据集(数百万个点),最快、最好的实现方式是什么?

1 个答案:

答案 0 :(得分:15)

您可以使用scipy.spatial.cKDTree

/* getfile client */ 
#include <stdio.h>      /* printf and standard I/O */ 
#include <sys/socket.h> /* socket, connect, socklen_t */ 
#include <arpa/inet.h>  /* sockaddr_in, inet_pton */ 
#include <string.h>     /* strlen */ 
#include <stdlib.h>     /* atoi */ 
#include <fcntl.h>      /* O_WRONLY, O_CREAT */ 
#include <unistd.h>     /* close, write, read */ 

/* Default server TCP port, used when no port is given on the command line. */
#define SRV_PORT 5105 
/* Size in bytes of the buffer that holds data returned by each recv() call. */
#define MAX_RECV_BUF 256 
/* Size in bytes of the buffer that holds the request line sent to the server. */
#define MAX_SEND_BUF 256 

/* Requests the named file over the given socket and saves it locally;
 * arguments are (socket descriptor, file name). Defined below main(). */
int recv_file(int ,char*); 
/*
 * getfile client entry point.
 *
 * Usage: <prog> <filename> <IP address> [port number]
 *
 * Connects over TCP to the given IPv4 address (on argv[3] if supplied,
 * otherwise SRV_PORT), calls recv_file() to fetch argv[1], then closes the
 * socket.
 *
 * Returns 0 on success. Terminates with EXIT_FAILURE on a usage error, an
 * invalid address or port, a socket/connect/close failure, or when
 * recv_file() reports an error.
 */
int main(int argc, char* argv[])
{
    int sock_fd;
    int transfer_rc;
    long port = SRV_PORT;
    struct sockaddr_in srv_addr;

    if (argc < 3) {
        printf("usage: %s <filename> <IP address> [port number]\n", argv[0]);
        exit(EXIT_FAILURE);
    }

    memset(&srv_addr, 0, sizeof(srv_addr)); /* zero-fill srv_addr structure */
    srv_addr.sin_family = AF_INET;          /* internet address family */

    /* convert command line argument to numeric IP */
    if (inet_pton(AF_INET, argv[2], &(srv_addr.sin_addr)) < 1) {
        printf("Invalid IP address\n");
        exit(EXIT_FAILURE);
    }

    /* if port number supplied, validate and use it, otherwise use SRV_PORT */
    if (argc > 3) {
        char *end = NULL;

        port = strtol(argv[3], &end, 10);
        /* reject empty input, trailing garbage and values outside 1..65535 */
        if (end == argv[3] || *end != '\0' || port < 1 || port > 65535) {
            printf("Invalid port number\n");
            exit(EXIT_FAILURE);
        }
    }
    srv_addr.sin_port = htons((unsigned short) port);

    /* create a client socket */
    sock_fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
    if (sock_fd < 0) {
        perror("socket error");
        exit(EXIT_FAILURE);
    }

    if (connect(sock_fd, (struct sockaddr*) &srv_addr, sizeof(srv_addr)) < 0) {
        perror("connect error"); /* report before close() can change errno */
        close(sock_fd);
        exit(EXIT_FAILURE);
    }

    /* report the port actually used, not unconditionally the default */
    printf("connected to:%s:%d ..\n", argv[2], (int) port);

    transfer_rc = recv_file(sock_fd, argv[1]); /* argv[1] = file name */

    /* close socket */
    if (close(sock_fd) < 0) {
        perror("socket close error");
        exit(EXIT_FAILURE);
    }

    /* propagate a failed transfer to the caller via the exit status */
    return (transfer_rc < 0) ? EXIT_FAILURE : 0;
}
/*
 * Send exactly len bytes from buf on sock, retrying after short sends.
 * Returns 0 on success, -1 if send() fails (errno is left as set by send()).
 */
static int send_all(int sock, const char *buf, size_t len)
{
    while (len > 0) {
        ssize_t sent = send(sock, buf, len, 0);

        if (sent < 0) {
            return -1;
        }
        buf += sent;
        len -= (size_t) sent;
    }
    return 0;
}

/*
 * Write exactly len bytes from buf to fd, retrying after short writes.
 * Returns 0 on success, -1 if write() fails (errno is left as set by write()).
 */
static int write_all(int fd, const char *buf, size_t len)
{
    while (len > 0) {
        ssize_t written = write(fd, buf, len);

        if (written < 0) {
            return -1;
        }
        buf += written;
        len -= (size_t) written;
    }
    return 0;
}

/*
 * Request file_name from the peer connected on sock and store the reply in a
 * local file of the same name.
 *
 * The request is the file name followed by a single '\n'. Everything the
 * peer sends until it closes the connection is written to the file, which is
 * created with mode 0644 or truncated if it already exists.
 *
 * Returns the number of bytes received, or -1 on error (name too long for
 * the MAX_SEND_BUF request buffer, or a send/open/recv/write/close failure).
 * The byte count is kept in an int, so it is only meaningful for transfers
 * up to INT_MAX bytes.
 */
int recv_file(int sock, char* file_name)
{
    char send_str[MAX_SEND_BUF]; /* request line sent to the server */
    char recv_str[MAX_RECV_BUF]; /* buffer to hold received data */
    ssize_t rcvd_bytes;
    int recv_count = 0;          /* number of recv() calls that returned data */
    int rcvd_file_size = 0;      /* total bytes received so far */
    int req_len;
    int f;                       /* descriptor of the local output file */

    /* build "<name>\n"; refuse names that do not fit rather than overflow */
    req_len = snprintf(send_str, sizeof send_str, "%s\n", file_name);
    if (req_len < 0 || (size_t) req_len >= sizeof send_str) {
        fprintf(stderr, "file name too long\n");
        return -1;
    }

    /* send the formatted request (including the newline), not the raw name */
    if (send_all(sock, send_str, (size_t) req_len) < 0) {
        perror("send error");
        return -1;
    }

    /* create the output file (0644 = rw-r--r--); O_TRUNC drops stale content
     * so a shorter download never leaves the tail of an older file behind */
    f = open(file_name, O_WRONLY | O_CREAT | O_TRUNC, 0644);
    if (f < 0) {
        perror("error creating file");
        return -1;
    }

    /* keep receiving until the peer closes the connection or recv() fails */
    while ((rcvd_bytes = recv(sock, recv_str, sizeof recv_str, 0)) > 0) {
        recv_count++;
        rcvd_file_size += (int) rcvd_bytes;

        if (write_all(f, recv_str, (size_t) rcvd_bytes) < 0) {
            perror("error writing to file"); /* before close() can change errno */
            close(f);
            return -1;
        }
    }

    /* a negative result is a socket error, not a normal end of file */
    if (rcvd_bytes < 0) {
        perror("recv error");
        close(f);
        return -1;
    }

    if (close(f) < 0) {
        perror("error closing file");
        return -1;
    }

    printf("Client Received: %d bytes in %d recv(s)\n", rcvd_file_size, recv_count);
    return rcvd_file_size;
}

这是一个示例,展示如何通过一次调用找到一组点的所有最近邻居:

import numpy as np
import scipy.spatial as spatial

points = np.array([(1, 2), (3, 4), (4, 5)])
point_tree = spatial.cKDTree(points)
# This finds the index of all points within distance 1 of [1.5,2.5].
print(point_tree.query_ball_point([1.5, 2.5], 1))
# [0]

# This gives the point in the KDTree which is within 1 unit of [1.5, 2.5]
print(point_tree.data[point_tree.query_ball_point([1.5, 2.5], 1)])
# [[1 2]]

# More than one point is within 3 units of [1.5, 1.6].
print(point_tree.data[point_tree.query_ball_point([1.5, 1.6], 3)])
# [[1 2]
#  [3 4]]

point_tree.query_ball_point

enter image description here