根据条件获取NumPy数组的连续元素组

时间:2019-07-04 12:41:56

标签: python arrays numpy grouping

我有一个NumPy数组,如下所示:

import numpy as np
a = np.array([1, 4, 2, 6, 4, 4, 6, 2, 7, 6, 2, 8, 9, 3, 6, 3, 4, 4, 5, 8])

和常数b = 6

基于previous question,我可以计算c的数量,该数量由a中的元素连续少于b的次数小于等于2的次数定义

from itertools import groupby
b = 6
sum(len(list(g))>=2 for i, g in groupby(a < b) if i)

所以在此示例中,c == 3

现在我想在每次满足条件时输出一个数组,而不是计算满足条件的次数。

因此,在此示例中,正确的输出为:

array1 = [1, 4, 2]
array2 = [4, 4]
array3 = [3, 4, 4, 5]

因为:

1, 4, 2, 6, 4, 4, 6, 2, 7, 6, 2, 8, 9, 3, 6, 3, 4, 4, 5, 8  # numbers in a
1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0  # (a<b)
^^^^^^^-----^^^^-----------------------------^^^^^^^^^^---  # (a<b) 2+ times consecutively
   1         2                                    3

到目前为止,我已经尝试了不同的选择:

np.isin((len(list(g))>=2 for i, g in groupby(a < b)if i), a)

np.extract((len(list(g))>=2 for i, g in groupby(a < b)if i), a)

但是他们都没有实现我所寻找的目标。有人可以指出我正确的Python工具,以便输出满足我条件的不同数组吗?

3 个答案:

答案 0 :(得分:2)

在测量my other answer的性能时,我注意到虽然它比Austin's solution快(对于长度<15000的数组),但是它的复杂度不是线性的。

基于this answer,我想到了使用np.split的以下解决方案,该解决方案比此处先前添加的两个答案都更有效:

array = np.append(a, -np.inf)  # padding so we don't lose last element
mask = array >= 6  # values to be removed
split_indices = np.where(mask)[0]
for subarray in np.split(array, split_indices + 1):
    if len(subarray) > 2:
        print(subarray[:-1])

给予:

[1. 4. 2.]
[4. 4.]
[3. 4. 4. 5.]

性能*:

enter image description here

*由perfplot

测量

答案 1 :(得分:1)

使用groupby并抓住组:

from itertools import groupby

lst = []
b = 6
for i, g in groupby(a, key=lambda x: x < b):
    grp = list(g)
    if i and len(grp) >= 2:
        lst.append(grp)

print(lst)

# [[1, 4, 2], [4, 4], [3, 4, 4, 5]]

答案 2 :(得分:1)

此任务与image labeling非常相似,但就您而言,它是一维的。 SciPy库提供了一些有用的图像处理功能,我们可以在此处使用

import numpy as np
from scipy.ndimage import (binary_dilation,
                           binary_erosion,
                           label)

a = np.array([1, 4, 2, 6, 4, 4, 6, 2, 7, 6, 2, 8, 9, 3, 6, 3, 4, 4, 5, 8])
b = 6  # your threshold
min_consequent_count = 2

mask = a < b
structure = [False] + [True] * min_consequent_count  # used for erosion and dilation
eroded = binary_erosion(mask, structure)
dilated = binary_dilation(eroded, structure)
labeled_array, labels_count = label(dilated)  # labels_count == c

for label_number in range(1, labels_count + 1):  # labeling starts from 1
    subarray = a[labeled_array == label_number]
    print(subarray)

给予:

[1 4 2]
[4 4]
[3 4 4 5]

说明:

  1. mask = a < b返回具有True值的boolean array,其中元素小于阈值b

    array([ True,  True,  True, False,  True,  True, False,  True, False,
           False,  True, False, False,  True, False,  True,  True,  True,
            True, False])
    
  2. 如您所见,结果包含一些True元素,它们周围没有任何其他True邻居。为了消除它们,我们可以使用binary erosion。我为此使用scipy.ndimage.binary_erosion。它的默认structure参数不适合我们的需求,因为它还会删除两个随后的True值,因此我构造了自己的值:

    >>> structure = [False] + [True] * min_consequent_count
    >>> structure
    [False, True, True]
    >>> eroded = binary_erosion(mask, structure)
    >>> eroded
    array([ True,  True, False, False,  True, False, False, False, False,
           False, False, False, False, False, False,  True,  True,  True,
           False, False])
    
  3. 我们设法删除了单个True值,但是我们需要获取其他组的初始配置。为此,我们将binary dilation与相同的structure一起使用:

    >>> dilated = binary_dilation(eroded, structure)
    >>> dilated
    array([ True,  True,  True, False,  True,  True, False, False, False,
           False, False, False, False, False, False,  True,  True,  True,
            True, False])
    

    binary_dilation的文档:link

  4. 最后一步,我们用scipy.ndimage.label标记每个组:

    >>> labeled_array, labels_count = label(dilated)
    >>> labeled_array
    array([1, 1, 1, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 3, 3, 3, 0])
    >>> labels_count
    3
    

    您可以看到labels_countc值-问题中的组数相同。 在这里,您可以通过布尔索引简单地获取子组:

    >>> a[labeled_array == 1]
    array([1, 4, 2])
    >>> a[labeled_array == 3]
    array([3, 4, 4, 5])