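numpy: test whether array items are above a certain value x times in a row?

Time: 2018-11-09 15:32:34

Tags: python numpy

I am trying to test whether the values in an array exceed a certain value a given number of times in a row.

For example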

arr1 = np.array([1,2,1,3,4,5,6,7])
arr2 = np.array([1,2,1,3,4,2,6,7])

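Say I want to test whether the items in the array are >= 3 for four consecutive periods. The test should return True for arr1 but False for arr2.

3 answers:

Answer 0 (score: 5)

Here's one way with convolution -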

def cross_thresh_convolve(arr, thresh, N):
    # Detect if arr crosses thresh for N consecutive times anywhere
    return (np.convolve(arr>=thresh,np.ones(N,dtype=int))==N).any()
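To see why comparing against N works, it can help to look at the intermediate convolution output. The following is a minimal sketch using the arrays from the question; the helper name window_counts is introduced here for illustration only and is not part of the answer. Each output entry counts how many of the (up to) N most recent items satisfied the condition, so an entry equal to N means N hits in a row.

```python
import numpy as np

def window_counts(arr, thresh, N):
    # 1 where the item meets the threshold, 0 elsewhere
    mask = (arr >= thresh).astype(int)
    # Full convolution with a length-N window of ones: running count of hits
    return np.convolve(mask, np.ones(N, dtype=int))

arr1 = np.array([1, 2, 1, 3, 4, 5, 6, 7])
arr2 = np.array([1, 2, 1, 3, 4, 2, 6, 7])

print(window_counts(arr1, 3, 4))  # [0 0 0 1 2 3 4 4 3 2 1] -> reaches 4
print(window_counts(arr2, 3, 4))  # [0 0 0 1 2 2 3 3 2 2 1] -> never reaches 4
```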

Or with binary erosion -

from scipy.ndimage import binary_erosion

def cross_thresh_erosion(arr, thresh, N):
    return binary_erosion(arr>=thresh, np.ones(N)).any()

Sample run -

In [43]: arr1 = np.array([1,2,1,3,4,5,6,7])
    ...: arr2 = np.array([1,2,1,3,4,2,6,7])

In [44]: print(cross_thresh_convolve(arr1, thresh=3, N=4))
    ...: print(cross_thresh_erosion(arr1, thresh=3, N=4))
    ...: print(cross_thresh_convolve(arr2, thresh=3, N=4))
    ...: print(cross_thresh_erosion(arr2, thresh=3, N=4))
True
True
False
False

Generic comparisons

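To cover generic comparisons, say if we want to look for greater-than or less-than, or simply test for equality against a value, we can use NumPy's built-in comparison functions to replace the arr>=thresh part of the earlier solutions and so get generic implementations, like so -

def consecutive_comp_convolve(arr, comp, N, comparison=np.greater_equal):
    # comparison is any elementwise NumPy comparison, e.g. np.less or np.equal
    return (np.convolve(comparison(arr,comp),np.ones(N,dtype=int))==N).any()

def consecutive_comp_erosion(arr, comp, N, comparison=np.greater_equal):
    return binary_erosion(comparison(arr,comp), np.ones(N)).any()

Hence, our specific case corresponds to the default comparison=np.greater_equal -

consecutive_comp_convolve(arr1, comp=3, N=4, comparison=np.greater_equal)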
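As an illustration of swapping in a different comparison, here is a small self-contained sketch that looks for runs below a value and runs equal to a value. The readings array is made-up data, and the function is restated with an explicit integer cast and bool() wrapper purely so the snippet runs on its own.

```python
import numpy as np

def consecutive_comp_convolve(arr, comp, N, comparison=np.greater_equal):
    # Elementwise comparison, cast to int so the convolution counts hits
    mask = comparison(arr, comp).astype(int)
    return bool((np.convolve(mask, np.ones(N, dtype=int)) == N).any())

readings = np.array([5, 1, 0, 2, 6, 1, 1])  # hypothetical data

# Three consecutive values below 3 (1, 0, 2)
print(consecutive_comp_convolve(readings, 3, 3, comparison=np.less))  # True
# There is no run of four values below 3
print(consecutive_comp_convolve(readings, 3, 4, comparison=np.less))  # False
# Three consecutive values equal to 2
print(consecutive_comp_convolve(np.array([2, 2, 2, 1]), 2, 3, comparison=np.equal))  # True
```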

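Answer 1 (score: 1)

Here is a low-tech but fast approach. Build the boolean array, take its cumsum(), and compare each element with the one n positions further along. If the difference is n, those n positions must be a streak of Trues.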

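def check_streak(a, th, n):
    # Prefix sums of the boolean mask, with a leading 0 so that a streak
    # starting at index 0 is also detected
    ps = np.concatenate(([0], (a>=th).cumsum()))
    # ps[i+n]-ps[i] counts the True values in the window a[i:i+n]
    return (ps[n:]-ps[:-n] == n).any()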
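One way to gain confidence in a prefix-sum check like this is to compare it against a plain loop on every short 0/1 array. The sketch below assumes a zero is prepended to the cumulative sum so that a run starting at index 0 is counted; check_streak_padded and has_streak_loop are names made up for this comparison.

```python
import itertools
import numpy as np

def check_streak_padded(a, th, n):
    # ps[i] = number of items >= th among a[:i]; the leading 0 covers i = 0
    ps = np.concatenate(([0], (a >= th).cumsum()))
    return bool((ps[n:] - ps[:-n] == n).any())

def has_streak_loop(a, th, n):
    # Reference implementation: track the length of the current run
    run = 0
    for value in a:
        run = run + 1 if value >= th else 0
        if run >= n:
            return True
    return False

# Exhaustively compare both versions on all 0/1 arrays up to length 8
mismatches = 0
for length in range(1, 9):
    for bits in itertools.product([0, 1], repeat=length):
        a = np.array(bits)
        for n in range(1, length + 1):
            if check_streak_padded(a, 1, n) != has_streak_loop(a, 1, n):
                mismatches += 1
print(mismatches)  # 0
```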

Answer 2 (score: 0)

Another solution (but slower than the other solutions)

import numpy as np
from numpy.lib.stride_tricks import as_strided

def f(arr, threshold=3, n=4):
    arr = as_strided(arr, shape=(arr.shape[0]-n+1, n), strides=2*arr.strides)
    return (arr >= threshold).all(axis=1).any()


# How it works:
# arr = np.array([1, 2, 3, 4, 5, 6, 7, 8])
# n = 4
# threshold = 3

# arr = as_strided(arr, shape=(arr.shape[0]-n+1, n), strides=2*arr.strides)
# print(arr)
# [[1 2 3 4]
#  [2 3 4 5]
#  [3 4 5 6]
#  [4 5 6 7]
#  [5 6 7 8]]

# print(arr >= threshold)
# [[False False  True  True]
#  [False  True  True  True]
#  [ True  True  True  True]
#  [ True  True  True  True]
#  [ True  True  True  True]]

# print((arr >= threshold).all(axis=1))
# [False False  True  True  True]

# print((arr >= threshold).all(axis=1).any())
# True
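For readers on NumPy 1.20 or newer, the same windowed view can be built with numpy.lib.stride_tricks.sliding_window_view, which checks the window size and returns a read-only view, whereas as_strided performs no bounds checking. This is a sketch of that substitution, not the answer's own code; the function name has_run_windowed and the early return for arrays shorter than n are assumptions added here.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def has_run_windowed(arr, threshold=3, n=4):
    arr = np.asarray(arr)
    # Assumption: an array shorter than n cannot contain a run of length n
    if n > arr.shape[0]:
        return False
    # Rows are the overlapping length-n windows of the boolean mask
    windows = sliding_window_view(arr >= threshold, n)
    return bool(windows.all(axis=1).any())

print(has_run_windowed(np.array([1, 2, 1, 3, 4, 5, 6, 7])))  # True
print(has_run_windowed(np.array([1, 2, 1, 3, 4, 2, 6, 7])))  # False
```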