我正在尝试测试数组中的值是否连续多次超过某个值。
例如
arr1 = np.array([1,2,1,3,4,5,6,7])
arr2 = np.array([1,2,1,3,4,2,6,7])
说我想测试一下连续四个周期中数组中的一项是否为>=3
。测试将为true
返回arr1
,但为false
返回arr2
。
答案 0 :(得分:5)
这是convolution
的一种方式-
def cross_thresh_convolve(arr, thresh, N):
# Detect if arr crosses thresh for N consecutive times anywhere
return (np.convolve(arr>=thresh,np.ones(N,dtype=int))==N).any()
或者用binary-dilation
-
from scipy.ndimage.morphology import binary_erosion
def cross_thresh_erosion(arr, thresh, N):
return binary_erosion(arr>=thresh, np.ones(N)).any()
样品运行-
In [43]: arr1 = np.array([1,2,1,3,4,5,6,7])
...: arr2 = np.array([1,2,1,3,4,2,6,7])
In [44]: print cross_thresh_convolve(arr1, thresh=3, N=4)
...: print cross_thresh_erosion(arr1, thresh=3, N=4)
...: print cross_thresh_convolve(arr2, thresh=3, N=4)
...: print cross_thresh_erosion(arr2, thresh=3, N=4)
True
True
False
False
常规比较
要涵盖一般比较,例如,如果我们要查找greater
或less-than
,或者只是简单地将相等性与值进行比较,则可以使用NumPy内置的比较函数来替换{{1} }从早期解决方案中分离出来,因此为自己提供了通用实现,例如-
arr>=thresh
因此,我们的具体案例将是-
def consecutive_comp_convolve(arr, comp, N, comparison=np.greater_equal):
return (np.convolve(comparison(arr,comp),np.ones(N,dtype=int))==N).any()
def consecutive_comp_erosion(arr, comp, N, comparison=np.greater_equal):
return binary_erosion(comparison(arr,comp), np.ones(N)).any()
答案 1 :(得分:1)
这是一种技术含量低但速度快的方法。构造布尔数组,形成cumsum()并将每个元素与n个位置进行比较。如果差异为n,则必须是True
s的条纹。
def check_streak(a, th, n):
ps = (a>=th).cumsum()
return (ps[n:]-ps[:ps.size-n] == n).any()
答案 2 :(得分:0)
另一种解决方案(但比其他解决方案要慢)
import numpy as np
from numpy.lib.stride_tricks import as_strided
def f(arr, threshold=3, n=4):
arr = as_strided(arr, shape=(arr.shape[0]-n+1, n), strides=2*arr.strides)
return (arr >= threshold).all(axis=1).any()
# How it works:
# arr = np.array([1, 2, 3, 4, 5, 6, 7, 8])
# n = 4
# threshold = 3
# arr = as_strided(arr, shape=(arr.shape[0]-n+1, n), strides=2*arr.strides)
# print(arr)
# [[1 2 3 4]
# [2 3 4 5]
# [3 4 5 6]
# [4 5 6 7]
# [5 6 7 8]]
# print(arr >= threshold)
# [[False False True True]
# [False True True True]
# [ True True True True]
# [ True True True True]
# [ True True True True]]
# print((arr >= threshold).all(axis=1))
# [False False True True True]
# print((arr >= threshold).all(axis=1).any())
# True