对于每个元素的给定实数数组,找到比当前元素少不超过0.5
的元素数,然后写入新数组。
例如:
原始数组:
[0.1, 0.7, 0.8, 0.85, 0.9, 1.5, 1.7]
结果数组:
[0, 0, 1, 2, 3, 0, 1]
解决此问题的算法和方法是什么?
重要的是,仅在负方向上选择点的邻域,这使得无法使用Kdtree
或Balltree
算法。
我所有的问题都是尝试使用here进行编码。
答案 0 :(得分:0)
这将解决您的特定任务。
def find_nearest_element(original_array):
result_array = []
for e in original_array:
result_array.append(len(original_array[(e-0.5 < original_array) & (e > original_array)]))
return result_array
original_array = np.array([0.1, 0.7, 0.8, 0.85, 0.9, 1.5, 1.7])
print(find_nearest_element(original_array))
输出:
[0, 0, 1, 2, 3, 0, 1]
编辑:对于较小的数组(大约len 10000),使用掩码比使用numba的版本明显更快。对于更大的阵列,使用Numba的版本更快。因此,这取决于要处理的数组大小。
一些运行时比较(以秒为单位):
For smaller arrays(size=250):
Using Numba 0.2569999694824219
Using mask 0.0350041389465332
For bigger arrays(size=40000):
Using Numba 1.4619991779327393
Using mask 4.280000686645508
我的设备上的收支平衡点约为10000(两者都需要大约0.33秒)。
答案 1 :(得分:0)
尽管下面的方法使用简单的逻辑并且易于编写,但是速度很慢。我们可以使用修饰的Numba函数来加快速度。这样可以将简单的循环任务加快到接近汇编语言的速度。
使用pip install numba
安装Numba。
from numba import jit
import numpy as np
# Create a numpy array of length 10000 with float values between 0 and 10
random_values = np.random.uniform(0.0,10.0,size=(100*100,))
@jit(nopython=True, nogil=True)
def find_nearest(input):
result = []
for e in input:
counter = 0
for j in input:
if j >= (e-0.5) and j < e:
counter += 1
result.append(counter)
return result
result = find_nearest(random_values)
请注意,将返回测试用例的预期结果:
test = [0.1, 0.7, 0.8, 0.85, 0.9, 1.5, 1.7]
result = find_nearest(test)
print result
返回:
[0, 0, 1, 2, 3, 0, 1]
答案 2 :(得分:0)
对于有序数组,此问题很容易解决。您只需要向后搜索并计算所有大于实际数字半径的数字。如果不再满足该条件,则可以退出内部循环(这样可以节省大量时间)。
示例
import numpy as np
from scipy import spatial
import numba as nb
@nb.njit(parallel=True)
def get_counts_2(Points_sorted,ind,r):
counts=np.zeros(Points_sorted.shape[0],dtype=np.int64)
for i in nb.prange(0,Points_sorted.shape[0]):
count=0
for j in range(i-1,0,-1):
if (Points_sorted[i]-r<Points_sorted[j]):
count+=1
else:
break
counts[ind[i]]=count
return counts
时间
r=0.001
Points=np.random.rand(1_000_000)
t1=time.time()
ind=np.argsort(Points)
Points_sorted=Points[ind]
counts=get_counts_2(Points_sorted,ind,r)
print(time.time()-t1)
#0.29s