numpy array: first occurrence of N consecutive values below a threshold

Posted: 2019-08-29 14:44:36

Tags: python numpy

I have a 1D numpy array, for example:

a = np.array([1, 4, 5, 7, 1, 2, 2, 4, 10])

I would like to get the index of the first number from which N consecutive values are all below some value x.

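In this case, with N=3 and x=3, I am looking for the first position from which three consecutive entries are all less than 3. That would be a[4] (a[4], a[5], a[6] are 1, 2, 2).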
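Stated precisely, reading the run as starting at the index itself (the reading under which a[4] is the answer for N=3, x=3), the index sought is:

$$
i^\ast = \min\bigl\{\, i : 0 \le i \le \operatorname{len}(a) - N,\ \ a_{i+k} < x \ \text{ for all } k = 0, \dots, N-1 \,\bigr\}
$$

with no answer when the set is empty. For the example, $i = 4$ qualifies because $a_4, a_5, a_6 = 1, 2, 2 < 3$, and every earlier window contains a value $\ge 3$.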

This could easily be done with a for loop over all the values, but I wonder whether there is a cleaner, more efficient way to do it.

3 Answers:

Answer 0 (score: 9)

Approach #1:

Here's a vectorized NumPy approach:

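import numpy as np

def start_valid_island(a, thresh, window_size):
    m = a<thresh                              # True where the value is below thresh
    me = np.r_[False,m,False]                 # pad so every island has a start and an end
    idx = np.flatnonzero(me[:-1]!=me[1:])     # alternating island starts and ends
    lens = idx[1::2]-idx[::2]                 # length of each island
    valid = lens >= window_size
    # argmax would return 0 when nothing qualifies, so check for a match first
    return idx[::2][valid.argmax()] if valid.any() else None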
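A step-by-step trace on the question's sample array may make the island detection easier to follow. It reproduces the intermediate arrays of start_valid_island; the names starts and ends are introduced here for readability only.

```python
import numpy as np

a = np.array([1, 4, 5, 7, 1, 2, 2, 4, 10])
thresh, window_size = 3, 3

m = a < thresh                           # [T F F F T T T F F]
me = np.r_[False, m, False]              # pad with False on both ends
idx = np.flatnonzero(me[:-1] != me[1:])  # positions where m switches on/off
starts, ends = idx[::2], idx[1::2]       # island k spans a[starts[k]:ends[k]]
lens = ends - starts                     # island lengths

print("starts:", starts)  # [0 4]
print("ends:  ", ends)    # [1 7]
print("lens:  ", lens)    # [1 3]
# first island that is long enough
print("answer:", starts[(lens >= window_size).argmax()])  # 4
```

The padding guarantees that the switch positions come in (start, end) pairs, which is why the even and odd entries of idx can be paired directly.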

Sample run:

In [44]: a
Out[44]: array([ 1,  4,  5,  7,  1,  2,  2,  4, 10])

In [45]: start_valid_island(a, thresh=3, window_size=3)
Out[45]: 4

In [46]: a[:3] = 1

In [47]: start_valid_island(a, thresh=3, window_size=3)
Out[47]: 0

Approach #2:

Using SciPy's binary erosion:

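import numpy as np
from scipy.ndimage import binary_erosion

def start_valid_island_v2(a, thresh, window_size):
    m = a<thresh
    k = np.ones(window_size,dtype=bool)
    # origin=-(window_size//2) shifts the structuring element to the right, so
    # out[i] is True iff m[i:i+window_size] is all True
    out = binary_erosion(m,k,origin=-(window_size//2))
    return out.argmax() if out.any() else None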
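The erosion marks, for each position i, whether m[i:i+window_size] is entirely True. As a cross-check that needs only NumPy (version 1.20 or later is assumed, since that is when sliding_window_view appeared), the same per-window mask can be built with numpy.lib.stride_tricks.sliding_window_view. The helper names all_below_windows and start_valid_island_swv are hypothetical. Unlike the erosion output, this mask has one entry per complete window (len(a) - window_size + 1 entries) rather than len(a).

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20


def all_below_windows(a, thresh, window_size):
    """Boolean array: entry i is True iff a[i:i+window_size] are all < thresh."""
    m = np.asarray(a) < thresh
    if len(m) < window_size:
        # no complete window fits; sliding_window_view would raise here
        return np.zeros(0, dtype=bool)
    # each row of the view is one length-window_size window of m (no copy)
    return sliding_window_view(m, window_size).all(axis=1)


def start_valid_island_swv(a, thresh, window_size):
    ok = all_below_windows(a, thresh, window_size)
    # argmax gives the first True; guard the case where there is none
    return int(ok.argmax()) if ok.any() else None


a = np.array([1, 4, 5, 7, 1, 2, 2, 4, 10])
print(all_below_windows(a, 3, 3))       # True only at index 4
print(start_valid_island_swv(a, 3, 3))  # 4
```

This costs O(len(a) * window_size) work, so it is mainly useful as a readable reference for small windows.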

Approach #3:

To complete the set, here's a loop-based one that short-circuits at the first match, using numba for efficiency:
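from numba import njit

@njit
def start_valid_island_v3(a, thresh, window_size):
    n = len(a)
    out = None  # stays None if no valid window exists
    for i in range(n-window_size+1):
        # check the window a[i:i+window_size], bailing out at the first value >= thresh
        found = True
        for j in range(window_size):
            if a[i+j]>=thresh:
                found = False
                break
        if found:
            out = i
            break  # short-circuit: stop at the first valid window
    return out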
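The loop above re-checks up to window_size elements for every start position, so its worst case is O(len(a) * window_size). A single-pass alternative keeps a running count of consecutive below-threshold values and stops as soon as the count reaches window_size. The sketch below is plain Python and the name first_run_start is hypothetical; it should also be compatible with numba's @njit, but that is an assumption rather than something shown in the answer.

```python
def first_run_start(a, thresh, window_size):
    """Index where the first run of window_size consecutive values < thresh begins, or None."""
    run = 0  # length of the current run of below-threshold values
    for i, v in enumerate(a):
        if v < thresh:
            run += 1
            if run == window_size:
                # the run that just reached window_size began window_size-1 steps ago
                return i - window_size + 1
        else:
            run = 0  # any value >= thresh breaks the run
    return None


print(first_run_start([1, 4, 5, 7, 1, 2, 2, 4, 10], 3, 3))  # 4
print(first_run_start([1, 2, 5, 1, 2], 3, 3))               # None
```

Each element is inspected at most once, so the cost no longer grows with window_size.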

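Timings (note that the numba version stops at the first valid window, which in this random array occurs near the start, so unlike the vectorized versions it does not process the whole array):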

In [142]: np.random.seed(0)
     ...: a = np.random.randint(0,10,(100000000))

In [145]: %timeit start_valid_island(a, thresh=3, window_size=3)
1 loop, best of 3: 810 ms per loop

In [146]: %timeit start_valid_island_v2(a, thresh=3, window_size=3)
1 loop, best of 3: 1.27 s per loop

In [147]: %timeit start_valid_island_v3(a, thresh=3, window_size=3)
1000000 loops, best of 3: 608 ns per loop

Answer 1 (score: 0)

Try something like this; it returns None if no number matches the condition:

def func(a, n, x):
    # return the index of the first window a[i:i+n] whose values are all < x
    for i in range(len(a) - n + 1):
        window = a[i:i+n]
        if all(j < x for j in window):
            return i
    return None  # no window matched

Answer 2 (score: 0)

For what it's worth, here's how to do it in vanilla Python:

a = [1,4,5,7,1,2,2,4,10]

# check every window of 3; the last one starts at len(a) - 3
res = next(i for i in range(len(a)-3+1) if all(j<3 for j in a[i:i+3]))
print(res)  # 4

That said, most NumPy solutions will probably be faster.

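Also note that the above raises StopIteration if no solution is found, so consider wrapping it in a try block, or give next() a default value as its second argument.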
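Passing a default to next() avoids the try block entirely. The sketch below also turns the hard-coded 3s into parameters n (window length) and x (threshold); the function name first_below_run is hypothetical.

```python
def first_below_run(a, n, x):
    """Index of the first window a[i:i+n] whose values are all < x, or None."""
    return next(
        (i for i in range(len(a) - n + 1) if all(v < x for v in a[i:i + n])),
        None,  # default returned instead of raising StopIteration
    )


a = [1, 4, 5, 7, 1, 2, 2, 4, 10]
print(first_below_run(a, 3, 3))  # 4
print(first_below_run(a, 3, 1))  # None
```

When n exceeds len(a), the range is empty and the default None is returned.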