Efficient pythonic way to snap a value to some grid

Time: 2011-12-10 15:38:50

Tags: python performance algorithm

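I have a one-dimensional grid, defined as a sorted list of float values. The points are not equidistant, but it is guaranteed that no two points coincide (distance == 0).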

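I need the most efficient way to snap any given value to the nearest grid point. The smartest approach I could come up with is the following (np is numpy, and myGrid is a numpy array):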

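import numpy as np

absDiff = np.abs(myGrid - myValue)
# The nearest point has the *smallest* distance, so this must be argmin, not argmax.
ix = np.argmin(absDiff)
snappedValue = myGrid[ix]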

The problem is that this approach is too slow, and I need something more efficient.

4 answers:

Answer 0 (score: 2)

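import bisect

def snap(myGrid, myValue):
    # Index of the first grid point strictly greater than myValue.
    ix = bisect.bisect_right(myGrid, myValue)
    if ix == 0:
        # myValue lies below the first grid point.
        return myGrid[0]
    elif ix == len(myGrid):
        # myValue lies at or above the last grid point.
        return myGrid[-1]
    else:
        # Pick whichever neighbour is closer; min() keeps the lower one on ties.
        return min(myGrid[ix - 1], myGrid[ix], key=lambda gridValue: abs(gridValue - myValue))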
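If many values need snapping at once, the same bracketing idea can be vectorized with np.searchsorted instead of calling bisect in a Python loop. A minimal sketch, assuming the grid is a sorted 1D array with at least two points; the name snap_sorted is hypothetical, and ties go to the lower grid point, as with min() above:

```python
import numpy as np

def snap_sorted(grid, values):
    """Snap each entry of `values` to the nearest point of the sorted 1D `grid`.

    Hypothetical helper: vectorized binary search via np.searchsorted.
    Assumes `grid` is sorted ascending and has at least two points.
    """
    grid = np.asarray(grid, dtype=float)
    values = np.asarray(values, dtype=float)
    # Index of the first grid point >= value (same as bisect_left).
    ix = np.searchsorted(grid, values)
    # Clip so that both neighbours ix - 1 and ix are valid indices;
    # values outside the grid then compare against the two end points.
    ix = np.clip(ix, 1, len(grid) - 1)
    left = grid[ix - 1]
    right = grid[ix]
    # Take the left neighbour when it is at least as close (ties go low).
    take_left = (values - left) <= (right - values)
    return np.where(take_left, left, right)
```

Each query costs O(log m) for a grid of m points, versus the O(m) np.abs/argmin scan in the question.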

Answer 1 (score: 2)

I wrote a function for numpy arrays:

import numpy as np

def ndsnap(points, grid):
    """
    Snap a 2D array of points to the rows of a 2D grid array.
    Each point will be snapped to the grid row with the smallest
    city-block (Manhattan) distance.

    Parameters
    ----------
    points: 2D array of shape (n, d). Must have the same number of columns as grid.
    grid: 2D array of shape (m, d). Must have the same number of columns as points.

    Returns
    -------
    A 2D array with one row per row of points. The i-th row is the
    row of grid to which the i-th row of points is closest.
    In case of ties, the point is snapped to the row of grid with the
    smaller index.
    """
    # (m, d) -> (m, d, 1) -> (1, d, m)
    grid_3d = np.transpose(grid[:,:,np.newaxis], [2,1,0])
    # (1, d, m) - (n, d, 1) -> (n, d, m); summing over d gives (n, m) distances.
    diffs = np.sum(np.abs(grid_3d - points[:,:,np.newaxis]), axis=1)
    best = np.argmin(diffs, axis=1)
    return grid[best,:]
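The transpose in ndsnap lines the grid up as (1, d, m) against the points as (n, d, 1). The sketch below, with the hypothetical name ndsnap_broadcast, computes the same city-block distances with the coordinate axis last, which may be easier to follow. Like ndsnap, it materializes an n*m*d temporary array, so memory grows with the grid size:

```python
import numpy as np

def ndsnap_broadcast(points, grid):
    """Hypothetical variant of ndsnap: same city-block snapping, different axis layout."""
    points = np.asarray(points, dtype=float)  # shape (n, d)
    grid = np.asarray(grid, dtype=float)      # shape (m, d)
    # (n, 1, d) - (1, m, d) broadcasts to (n, m, d): every point against every grid row.
    diffs = np.abs(points[:, np.newaxis, :] - grid[np.newaxis, :, :])
    # Sum over the coordinate axis to get city-block distances, shape (n, m).
    dist = diffs.sum(axis=2)
    # argmin returns the first minimum, so ties go to the lower grid row index.
    best = np.argmin(dist, axis=1)
    return grid[best]
```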

Answer 2 (score: 0)

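In the typical case, your new point will fall between two existing grid points. Use binary search to find the two points it lies between, then pick whichever of the two is closer. That's it.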

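All that remains is handling the edge cases correctly: a point that falls before the first or after the last grid point, and a point that lands exactly on an existing grid point.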
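The steps described above can be written out without the bisect module. A minimal sketch, assuming grid is a sorted, non-empty sequence of floats; snap_binary_search is a hypothetical name, and ties go to the lower point:

```python
def snap_binary_search(grid, value):
    """Snap `value` to the nearest point of the sorted, non-empty sequence `grid`.

    Hypothetical helper: a hand-rolled binary search with explicit edge cases.
    """
    # Edge cases: value at or beyond either end of the grid.
    if value <= grid[0]:
        return grid[0]
    if value >= grid[-1]:
        return grid[-1]
    lo, hi = 0, len(grid) - 1
    # Invariant: grid[lo] < value < grid[hi]; shrink until lo and hi are neighbours.
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if grid[mid] == value:
            return grid[mid]  # value sits exactly on a grid point
        if grid[mid] < value:
            lo = mid
        else:
            hi = mid
    # value lies strictly between two neighbouring grid points: take the closer one.
    if value - grid[lo] <= grid[hi] - value:
        return grid[lo]
    return grid[hi]
```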

Answer 3 (score: 0)

To expand on Nolan Conaway's answer:

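Since the Manhattan (city-block) metric is used anyway, the following is faster for a complete rectangular grid, provided the grid is large enough: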

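import numpy as np

def ndsnap_regular(points, *grid_axes):
    snapped = []
    # The city-block distance is a sum over axes, so each coordinate
    # can be snapped to its own axis independently.
    for i, ax in enumerate(grid_axes):
        # (len(ax), 1) - (n,) -> (len(ax), n) differences for coordinate i.
        diff = ax[:, np.newaxis] - points[:, i]
        best = np.argmin(np.abs(diff), axis=0)
        snapped.append(ax[best])
    # Per-axis results have shape (d, n); transpose to (n, d).
    return np.array(snapped).T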

Here, grid_axes is just the tuple of axis coordinate arrays, i.e. the arrays that a function such as numpy.meshgrid would take to build the full grid.
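The timing session calls cartesian_product, which is not a NumPy function and is not defined in the answer. One plausible definition, labeled hypothetical here, stacks the output of numpy.meshgrid into an (m, d) array of grid points:

```python
import numpy as np

def cartesian_product(*axes):
    """Hypothetical helper: expand 1D axis arrays into an (m, d) array of grid points.

    Rows are ordered with the first axis varying slowest (indexing="ij").
    """
    mesh = np.meshgrid(*axes, indexing="ij")
    # Flatten each coordinate array and put the coordinates side by side.
    return np.stack([m.ravel() for m in mesh], axis=-1)
```

With indexing="ij" the first axis varies slowest, so ndsnap's lowest-row-index tie-breaking should pick the same point as the per-axis argmin in ndsnap_regular.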

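For example, snapping 100 3D points onto a 50x50x50 grid (np.linspace(0, 1) returns 50 samples by default; cartesian_product is not a NumPy function and is assumed to expand the axes into a (125000, 3) array of grid points):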

>>> n_points = 100
>>> points = np.random.random((n_points, 3))
>>> grid = (np.linspace(0, 1), np.linspace(0, 1), np.linspace(0, 1))
>>> grid_expanded = cartesian_product(*grid)
>>> len(grid_expanded)
125000
>>> %timeit ndsnap(points, grid_expanded)
418 ms ± 3.52 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
>>> %timeit ndsnap_regular(points, *grid)
86.5 µs ± 922 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
>>> (ndsnap(points, grid_expanded) == ndsnap_regular(points, *grid)).all()
True
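ndsnap_regular still builds a (len(ax), n_points) difference matrix for every axis. If each axis is sorted, that matrix can be avoided by using binary search per axis, as in the first answer. A sketch under those assumptions (sorted axes with at least two values each); ndsnap_regular_sorted is a hypothetical name, and it has not been benchmarked against the versions above:

```python
import numpy as np

def ndsnap_regular_sorted(points, *grid_axes):
    """Hypothetical variant of ndsnap_regular for sorted axes.

    Snaps each coordinate independently (valid for the city-block metric on a
    full rectangular grid) using binary search instead of a difference matrix.
    Each axis must be sorted ascending and contain at least two values.
    """
    points = np.asarray(points, dtype=float)
    snapped = []
    for i, ax in enumerate(grid_axes):
        ax = np.asarray(ax, dtype=float)
        col = points[:, i]
        # Neighbouring axis values around each coordinate.
        ix = np.clip(np.searchsorted(ax, col), 1, len(ax) - 1)
        left, right = ax[ix - 1], ax[ix]
        # Prefer the lower neighbour on ties, like argmin in ndsnap_regular.
        snapped.append(np.where(col - left <= right - col, left, right))
    # One column per axis: shape (n_points, d).
    return np.stack(snapped, axis=1)
```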