快速评估产品

时间:2013-08-13 13:10:02

标签: python numpy

我想计算大量点位置到一组预定位置的最小距离(问题不是空间的,空间的大小是> 60)。我已经使用了cKDTree,但是为了避免使用scipy,我想知道是否有一种聪明的方法来计算numpy数组。循环很容易:

point_locs = # Shape n_dims * n_samples
test_points = # Shape n_dims * n_points
min_dist = np.zeros ( n_points )
for i in n_points:
   min_dist[i] = np.sum(( point_locs - test_points[:,i])**2,axis=1).argmin()

这有什么更快的吗?通常,n_points大约为10 ^ 5-10 ^ 7。

2 个答案:

答案 0 :(得分:3)

来自scipy.spatial.KDTree的文档:

  

对于大尺寸(20已经很大),不要指望它比蛮力跑得快得多。高维最近邻查询是计算机科学中一个重要的开放问题。

因此,不仅是计算机科学中的一个开放性问题,你的强力方法很可能是一个适当的次优选择。如果您可以利用数据中的某些已知结构,即所有点都属于n个已知空间区域之一,那么您可以将问题分解下来。

答案 1 :(得分:1)

您的代码不是有效的Python,我认为您的形状令人困惑......除此之外,如果您有足够的内存,您可以通过使用广播来矢量化距离计算来摆脱循环。如果您有这些数据:

n_set_points = 100
n_test_points = 10000
n_dims = 60

set_points = np.random.rand(n_set_points, n_dims) 
test_points = np.random.rand(n_test_points, n_dims)

然后这是最直接的计算:

# deltas.shape = (n_set_points, n_test_point, n_dims)
deltas = (set_points[:, np.newaxis, :] -
          test_points[np.newaxis, ...])

# dist[j, k] holds the squared distance between the
# j-th set_point and the k-th test point
dist = np.sum(deltas*deltas, axis=-1)

# nearest[j] is the index of the set_point closest to
# each test_point, has shape (n_test_points,)
nearest = np.argmin(dist, axis=0)

交易破坏者是否可以将deltas存储在内存中:它可以是一个庞大的阵列。如果你这样做,通过更加神秘的距离计算可以获得一些性能,但效率更高:

dist = np.einsum('jkd,jkd->jk', deltas, deltas)

如果deltas太大,请将test_points分解为可管理的块,并循环遍历块,例如:

def nearest_neighbor(set_pts, test_pts, chunk_size):
    n_test_points = len(test_pts)
    ret = np.empty((n_test_points), dtype=np.intp)

    for chunk_start in xrange(0, n_test_points ,chunk_size):
        deltas = (set_pts[:, np.newaxis, :] -
                  test_pts[np.newaxis,
                           chunk_start:chunk_start + chunk_size, :])
        dist = np.einsum('jkd,jkd->jk', deltas,deltas)
        ret[chunk_start:chunk_start + chunk_size] = np.argmin(dist, axis=0)
    return ret

%timeit nearest_neighbor(set_points, test_points, 1)
1 loops, best of 3: 283 ms per loop

%timeit nearest_neighbor(set_points, test_points, 10)
1 loops, best of 3: 175 ms per loop

%timeit nearest_neighbor(set_points, test_points, 100)
1 loops, best of 3: 384 ms per loop

%timeit nearest_neighbor(set_points, test_points, 1000)
1 loops, best of 3: 365 ms per loop

%timeit nearest_neighbor(set_points, test_points, 10000)
1 loops, best of 3: 374 ms per loop

因此,通过部分矢量化可以获得一些性能。