numpy中是否有一种方法可以验证一个数组是否包含在另一个数组中?

时间:2019-11-08 11:00:35

标签: python numpy

我想验证一个numpy数组在另一个数组中是否是连续序列。

例如

a = np.array([1,2,3,4,5,6,7])
b = np.array([3,4,5])
c = np.array([2,3,4,6])

预期结果将是:

is_sequence_of(b, a) # should return True
is_sequence_of(c, a) # should return False

我想知道是否有一个执行此操作的numpy方法。

1 个答案:

答案 0 :(得分:4)

方法1

我们可以将np.searchsorted与-p一起使用

def isin_seq(a,b):
    # Look for the presence of b in a, while keeping the sequence
    sidx = a.argsort()
    idx = np.searchsorted(a,b,sorter=sidx)
    idx[idx==len(a)] = 0
    ssidx = sidx[idx]
    return (np.diff(ssidx)==1).all() & (a[ssidx]==b).all()

请注意,这假设输入数组没有重复项。

样品运行-

In [42]: isin_seq(a,b) # search for the sequence b in a
Out[42]: True

In [43]: isin_seq(c,b) # search for the sequence b in c
Out[43]: False

方法2

另一个与skimage.util.view_as_windows-

from skimage.util import view_as_windows

def isin_seq_v2(a,b):
    return (view_as_windows(a,len(b))==b).all(1).any()

方法3

这也可以视为模板匹配问题,因此,对于整数,我们可以将OpenCV的内置函数用于template-matchingcv2.matchTemplate(受this post启发) ,就像这样-

import cv2 
from cv2 import matchTemplate as cv2m

def isin_seq_v3(arr,seq):
    S = cv2m(arr.astype('uint8'),seq.astype('uint8'),cv2.TM_SQDIFF)
    return np.isclose(S,0).any()

方法4

我们的方法可以受益于基于short-circuiting的方法。因此,我们将numbafrom numba import njit @njit def isin_seq_numba(a,b): m = len(a) n = len(b) for i in range(m-n+1): for j in range(n): if a[i+j]!=b[j]: break if j==n-1: return True return False 一起使用,以提高性能-

Tue Oct 30 12:57:49 +0000 2012