我有一个numpy数组,带有一堆单调递增的值。说,
a = [1,2,3,4,6,10,10,11,14]
a_arr=np.array(a)
也说
thresh = 4
我想创建一个数组,其中包含a_arr
子集的索引,该子集遍历数组,选择元素但忽略的元素与最后一个选择的间距至少为thresh
。使用算法可以更容易地描述这一点:
def select_idx(a, thresh):
ret = []
for idx, elt in enumerate(a):
if len(ret) == 0 or elt >= a[ret[-1]] + thresh:
ret.append(idx)
return ret
显然我可以使用这个函数来做到这一点,但这似乎很慢。有什么方法可以在numpy中进行矢量化吗?
感谢。
P.S。在此示例中,select_idx(a,thresh)= [0,4,5,8]
编辑:此问题的近似版本可能更容易进行向量化:将数字行划分为大小为thresh
的存储桶,我想从a中的第一个值开始。因此,此示例中的桶分隔符将为0,4,8,12,16 ....选择作为其存储桶中第一个元素的数字的索引。 (是的,我意识到这与我之前写的不一样。)
答案 0 :(得分:0)
这是针对您的近似问题的矢量化解决方案:
idx = np.cumsum(np.bincount((a-a[0])/thresh))[:-1]
这将为您提供除第一个零之外的所有索引,它始终存在。这是解释:
(a-a[0])/thresh
执行整数除法(假设a
具有整数dtype),以将值分组到thresh
组中。
cumsum(bincount(...))
计算每个组的大小并将其转换为索引。请注意,如果存储区bincount
中没有值将报告0,则此数组中可能会有重复。
最后,我们丢弃最后一个索引,该索引对应于a
的大小。或者,如果索引的顺序无关紧要,您可以利用它来获得零索引:
idx = np.cumsum(np.bincount((a-a[0])/thresh)) % len(a)