加速Python cKDTree

时间:2018-05-08 03:04:16

标签: python performance for-loop vectorization nearest-neighbor

我目前有一个我创建的函数,它将蓝点与其(最多)3个最近邻居连接在55的像素范围内.vertices_xy_list是一个非常大的列表或点(嵌套列表)约5000-10000对

vertices_xy_list的示例:

[[3673.3333333333335, 2483.3333333333335],
 [3718.6666666666665, 2489.0],
 [3797.6666666666665, 2463.0],
 [3750.3333333333335, 2456.6666666666665],...]

我目前编写了这个calculate_draw_vertice_lines()函数,它在While循环中使用CKDTree来查找55像素内的所有点,然后用绿线连接它们。

可以看出,随着列表变长,这会呈指数级变慢。有没有什么方法可以显着加快这个功能?如矢量化操作?

def calculate_draw_vertice_lines():

    global vertices_xy_list
    global cell_wall_lengths
    global list_of_lines_references

    index = 0

    while True:

        if (len(vertices_xy_list) == 1):

            break

        point_tree = spatial.cKDTree(vertices_xy_list)

        index_of_closest_points = point_tree.query_ball_point(vertices_xy_list[index], 55)

        index_of_closest_points.remove(index)

        for stuff in index_of_closest_points:

            list_of_lines_references.append(plt.plot([vertices_xy_list[index][0],vertices_xy_list[stuff][0]] , [vertices_xy_list[index][1],vertices_xy_list[stuff][1]], color = 'green'))

            wall_length = math.sqrt( (vertices_xy_list[index][0] - vertices_xy_list[stuff][0])**2 + (vertices_xy_list[index][1] - vertices_xy_list[stuff][1])**2 )

            cell_wall_lengths.append(wall_length)

        del vertices_xy_list[index]

    fig.canvas.draw()

enter image description here

1 个答案:

答案 0 :(得分:2)

如果我理解正确选择绿线的逻辑,则无需在每次迭代时创建KDTree。对于蓝点的每对(p1,p2),当且仅当以下保持时,应绘制该线:

  1. p1是p2的3个最近邻居之一。
  2. p2是p1的3个最近邻居之一。
  3. dist(p1,p2)< 55。
  4. 您可以创建一次KDTree并有效地创建绿线列表。下面是实现的一部分,它返回一系列索引对的列表,在这些索引之间需要绘制绿线。我的机器上的运行时间约为0.5秒,持续10,000点。

    import numpy as np
    from scipy import spatial
    
    
    data = np.random.randint(0, 1000, size=(10_000, 2))
    
    def get_green_lines(data):
        tree = spatial.cKDTree(data)
        # each key in g points to indices of 3 nearest blue points
        g = {i: set(tree.query(data[i,:], 4)[-1][1:]) for i in range(data.shape[0])}
    
        green_lines = list()
        for node, candidates in g.items():
            for node2 in candidates:
                if node2 < node:
                    # avoid double-counting
                    continue
    
                if node in g[node2] and spatial.distance.euclidean(data[node,:], data[node2,:]) < 55:
                    green_lines.append((node, node2))
    
        return green_lines
    

    您可以按如下方式绘制绿线:

    green_lines = get_green_lines(data)
    fig, ax = plt.subplots()
    ax.scatter(data[:, 0], data[:, 1], s=1)
    from matplotlib import collections as mc
    lines = [[data[i], data[j]] for i, j in green_lines]
    line_collection = mc.LineCollection(lines, color='green')
    ax.add_collection(line_collection)
    

    示例输出:

    enter image description here