加速Numpy / Python中的数组查询

时间:2012-02-16 20:37:19

标签: python arrays performance numpy vectorization

我有一个点数组(称为 points ),由~30000 x,y和z值组成。我还有一个单独的点数组(称为顶点),约为40000 x,y和z值。后一个数组索引边长大小的一些立方体的左下角。我想找出哪些点位于哪些立方体中,以及每个立方体中有多少个点。我写了一个循环来做这个,它的工作原理如下:

for i in xrange(len(vertices)):        
    cube=((vertices[i,0]<= points[:,0]) & 
    (points[:,0]<(vertices[i,0]+size)) & 
    (vertices[i,1]<= points[:,1]) & 
    (points[:,1] < (vertices[i,1]+size)) &
    (vertices[i,2]<= points[:,2]) & 
    (points[:,2] < (vertices[i,2]+size))
    )
    numpoints[i]=len(points[cube])

(循环是对各个立方体进行排序,“立方体”创建一个布尔数量的索引。)然后我将点[cube]存储在某处,但这并不会减慢我的速度;这是“立方体=”的创造。

我想加快这个循环(在macbook pro上完成需要几十秒)。我尝试在C中重写“cube =”部分,如下所示:

for i in xrange(len(vertices)):
    cube=zeros(pp, dtype=bool)
    code="""
            for (int j=0; j<pp; ++j){

                if (vertices(i,0)<= points(j,0))
                 if (points(j,0) < (vertices(i,0)+size))
                  if (vertices(i,1)<= points(j,1))
                   if (points(j,1) < (vertices(i,1)+size))
                    if (vertices(i,2)<= points(j,2))
                     if (points(j,2) < (vertices(i,2)+size))
                      cube(j)=1;
            }
        return_val = 1;"""

    weave.inline(code,
    ['vertices', 'points','size','pp','cube', 'i']) 
    numpoints[i]=len(points[cube])

这加快了一点两倍。在C中重写两个循环实际上使它仅比原始的numpy-only版本略快,因为频繁引用了跟踪哪些点在哪些多维数据集中所必需的数组对象。我怀疑它可以更快地做到这一点,而且我错过了一些东西。谁能建议如何加快速度?我是numpy / python的新手,并提前感谢。

1 个答案:

答案 0 :(得分:3)

您可以使用scipy.spatial.cKDTree来加速此类计算。

以下是代码:

import time
import numpy as np

#### create some sample data ####
np.random.seed(1)

V_NUM = 6000
P_NUM = 8000

size = 0.1

vertices = np.random.rand(V_NUM, 3)
points = np.random.rand(P_NUM, 3)

numpoints = np.zeros(V_NUM, np.int32)

#### brute force ####
start = time.clock()
for i in xrange(len(vertices)):        
    cube=((vertices[i,0]<= points[:,0]) & 
    (points[:,0]<(vertices[i,0]+size)) & 
    (vertices[i,1]<= points[:,1]) & 
    (points[:,1] < (vertices[i,1]+size)) &
    (vertices[i,2]<= points[:,2]) & 
    (points[:,2] < (vertices[i,2]+size))
    )
    numpoints[i]=len(points[cube])

print time.clock() - start

#### KDTree ####
from scipy.spatial import cKDTree
center_vertices = vertices + [[size/2, size/2, size/2]]
start = time.clock()
tree_points = cKDTree(points)
_, result = tree_points.query(center_vertices, k=100, p = np.inf, distance_upper_bound=size/2)
numpoints2 = np.zeros(V_NUM, np.int32)
for i, neighbors in enumerate(result):
    numpoints2[i] = np.sum(neighbors!=P_NUM)

print time.clock() - start
print np.all(numpoints == numpoints2)
  • 首先将立方角位置更改为中心位置。

center_vertices = vertices + [[size/2, size/2, size/2]]

  • 从积分
  • 创建cKDTree

tree_points = cKDTree(points)

  • 查询,k是要返回的最近邻居的数量,p = np.inf表示最大坐标差距,distance_upper_bound是最大距离。

_, result = tree_points.query(center_vertices, k=100, p = np.inf, distance_upper_bound=size/2)

输出结果为:

2.04113164434
0.11087783696
True

如果多维数据集中有超过100个点,您可以在for循环中通过neighbors[-1] == P_NUM进行检查,并对这些顶点执行k = 1000查询。