查找numpy数组的子组

时间:2014-07-14 15:23:19

标签: numpy

我有一个像这样的numpy数组:

A = ([249,    250,   3016,   3017,   5679,   5680,   8257,   8258,
    10756,  10757,  13178,  13179,  15531,  15532,  17824,  17825,
    20058,  20059,  22239,  22240,  24373,  24374,  26455,  26456,
    28491,  28492,  30493,  30494,  32452,  32453,  34377,  34378,
    36264,  36265,  38118,  38119,  39939,  39940,  41736,  41737,
    43501,  43502,  45237,  45238,  46950,  46951,  48637,  48638]) 

我想编写一个小脚本,找到数组的值的子组,其差值小于某个阈值,比方说3,并返回子组的最高值。在A数组的情况下,输出应为:

A_out =([250,3017,5680,8258,10757,13179,...])

那是否有一个numpy函数?

2 个答案:

答案 0 :(得分:1)

这是一个矢量化的Numpy方法。

首先,数据(在一个numpy数组中)和阈值:

In [41]: A = np.array([249,    250,   3016,   3017,   5679,   5680,   8257,   8258,
    10756,  10757,  13178,  13179,  15531,  15532,  17824,  17825,
    20058,  20059,  22239,  22240,  24373,  24374,  26455,  26456,
    28491,  28492,  30493,  30494,  32452,  32453,  34377,  34378,
    36264,  36265,  38118,  38119,  39939,  39940,  41736,  41737,
    43501,  43502,  45237,  45238,  46950,  46951,  48637,  48638])

In [42]: threshold = 3

以下产生数组delta。它与delta = np.diff(A)几乎相同,但我希望在delta的末尾添加一个大于阈值的值。

In [43]: delta = np.hstack((diff(A), threshold + 1))

现在,群组最大值只是A[delta > threshold]

In [46]: A[delta > threshold]
Out[46]: 
array([  250,  3017,  5680,  8258, 10757, 13179, 15532, 17825, 20059,
       22240, 24374, 26456, 28492, 30494, 32453, 34378, 36265, 38119,
       39940, 41737, 43502, 45238, 46951, 48638])

或者,如果您愿意,A[delta >= threshold]。这给出了这个例子的相同结果:

In [47]: A[delta >= threshold]
Out[47]: 
array([  250,  3017,  5680,  8258, 10757, 13179, 15532, 17825, 20059,
       22240, 24374, 26456, 28492, 30494, 32453, 34378, 36265, 38119,
       39940, 41737, 43502, 45238, 46951, 48638])

有一种情况,这个答案与@ DrV的答案不同。从您的描述中,我不清楚如何处理诸如1, 2, 3, 4, 5, 6之类的一组值。连续差异均为1,但第一个和最后一个之间的差异为5.上面的numpy计算将这些视为一个组。 @ DrV的回答将创建两个小组。

答案 1 :(得分:-1)

解释1:组中项目的值与组中第一项的值不得超过3个单位

这是NumPy功能达到极限的原因之一。由于您将不得不遍历列表,我建议使用纯Python方法:

first_of_group = A[0]
previous = A[0]
group_lasts = []
for a in A[1:]:
    # if this item no longer belongs to the group
    if abs(a - first_of_group) > 3:
        group_lasts.append(previous)
        first_of_group = a
    previous = a
# add the last item separately, because it is always a last of the group
group_lasts.append(a)

现在你的小组持续时间为group_lasts

此处使用任何NumPy数组功能似乎没有提供太多帮助。

解释2:组中项目的值与前一项目的差异不得超过3个单位

这更容易,因为我们可以轻松地在Warren Weckesser的答案中形成分组符休息列表。 NumPy在这里提供了很多帮助。