在实数数组中查找最接近元素的最快方法

时间:2019-02-13 12:36:03

标签: python search tree binary-search-tree kdtree

对于每个元素的给定实数数组,找到比当前元素少不超过0.5的元素数,然后写入新数组。

例如:

原始数组:

[0.1, 0.7, 0.8, 0.85, 0.9, 1.5, 1.7]

结果数组:

[0,   0,   1,   2,    3,   0,   1]

解决此问题的算法和方法是什么?

重要的是,仅在负方向上选择点的邻域,这使得无法使用KdtreeBalltree算法。

我所有的问题都是尝试使用here进行编码。

3 个答案:

答案 0 :(得分:0)

这将解决您的特定任务。

def find_nearest_element(original_array):
    result_array = []
    for e in original_array:
        result_array.append(len(original_array[(e-0.5 < original_array) & (e > original_array)]))
    return result_array

original_array = np.array([0.1, 0.7, 0.8, 0.85, 0.9, 1.5, 1.7])
print(find_nearest_element(original_array))

输出:

[0, 0, 1, 2, 3, 0, 1]

编辑:对于较小的数组(大约len 10000),使用掩码比使用numba的版本明显更快。对于更大的阵列,使用Numba的版本更快。因此,这取决于要处理的数组大小。

一些运行时比较(以秒为单位):

For smaller arrays(size=250):
Using Numba 0.2569999694824219
Using mask 0.0350041389465332
For bigger arrays(size=40000):
Using Numba 1.4619991779327393
Using mask 4.280000686645508

我的设备上的收支平衡点约为10000(两者都需要大约0.33秒)。

答案 1 :(得分:0)

尽管下面的方法使用简单的逻辑并且易于编写,但是速度很慢。我们可以使用修饰的Numba函数来加快速度。这样可以将简单的循环任务加快到接近汇编语言的速度。

使用pip install numba安装Numba。

from numba import jit
import numpy as np

# Create a numpy array of length 10000 with float values between 0 and 10
random_values = np.random.uniform(0.0,10.0,size=(100*100,))

@jit(nopython=True, nogil=True)
def find_nearest(input):
  result = []
  for e in input:
    counter = 0
    for j in input:
      if j >= (e-0.5) and j < e:
        counter += 1
    result.append(counter)
  return result

result = find_nearest(random_values)

请注意,将返回测试用例的预期结果:

test = [0.1, 0.7, 0.8, 0.85, 0.9, 1.5, 1.7]
result = find_nearest(test)
print result

返回:

[0, 0, 1, 2, 3, 0, 1]

答案 2 :(得分:0)

对于有序数组,此问题很容易解决。您只需要向后搜索并计算所有大于实际数字半径的数字。如果不再满足该条件,则可以退出内部循环(这样可以节省大量时间)。

示例

import numpy as np
from scipy import spatial
import numba as nb

@nb.njit(parallel=True)
def get_counts_2(Points_sorted,ind,r):
  counts=np.zeros(Points_sorted.shape[0],dtype=np.int64)
  for i in nb.prange(0,Points_sorted.shape[0]):
    count=0
    for j in range(i-1,0,-1):
      if (Points_sorted[i]-r<Points_sorted[j]):
        count+=1
      else:
        break
    counts[ind[i]]=count
  return counts

时间

r=0.001
Points=np.random.rand(1_000_000)

t1=time.time()
ind=np.argsort(Points)
Points_sorted=Points[ind]
counts=get_counts_2(Points_sorted,ind,r)
print(time.time()-t1)
#0.29s