Python / NumPy首次出现子数组

时间:2011-08-17 22:20:24

标签: python numpy arrays

在Python或NumPy中,找出第一次出现的子阵列的最佳方法是什么?

例如,我有

a = [1, 2, 3, 4, 5, 6]
b = [2, 3, 4]

找出b出现在哪里的最快方式(运行时间)是什么?我理解字符串这非常容易,但对于列表或numpy ndarray呢?

非常感谢!

[编辑]我更喜欢numpy解决方案,因为根据我的经验,numpy矢量化比Python列表理解要快得多。同时,大数组是巨大的,所以我不想把它转换成字符串;这将是(太长)。

11 个答案:

答案 0 :(得分:16)

我的第一个答案,但我认为这应该有用......

[x for x in xrange(len(a)) if a[x:x+len(b)] == b]

返回模式开始的索引。

答案 1 :(得分:16)

我假设你正在寻找一个特定于numpy的解决方案,而不是简单的列表理解或for循环。一种方法可能是使用rolling window技术来搜索适当大小的窗口。这是rolling_window函数:

>>> def rolling_window(a, size):
...     shape = a.shape[:-1] + (a.shape[-1] - size + 1, size)
...     strides = a.strides + (a. strides[-1],)
...     return numpy.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
... 

然后你可以做类似

的事情
>>> a = numpy.arange(10)
>>> numpy.random.shuffle(a)
>>> a
array([7, 3, 6, 8, 4, 0, 9, 2, 1, 5])
>>> rolling_window(a, 3) == [8, 4, 0]
array([[False, False, False],
       [False, False, False],
       [False, False, False],
       [ True,  True,  True],
       [False, False, False],
       [False, False, False],
       [False, False, False],
       [False, False, False]], dtype=bool)

要使其真正有用,您必须使用all沿轴1减少它:

>>> numpy.all(rolling_window(a, 3) == [8, 4, 0], axis=1)
array([False, False, False,  True, False, False, False, False], dtype=bool)

然后你可以使用它然而你使用布尔数组。获取索引的简单方法:

>>> bool_indices = numpy.all(rolling_window(a, 3) == [8, 4, 0], axis=1)
>>> numpy.mgrid[0:len(bool_indices)][bool_indices]
array([3])

对于列表,您可以调整其中一个rolling window迭代器以使用类似的方法。

对于非常大型数组和子数组,您可以像这样保存内存:

>>> windows = rolling_window(a, 3)
>>> sub = [8, 4, 0]
>>> hits = numpy.ones((len(a) - len(sub) + 1,), dtype=bool)
>>> for i, x in enumerate(sub):
...     hits &= numpy.in1d(windows[:,i], [x])
... 
>>> hits
array([False, False, False,  True, False, False, False, False], dtype=bool)
>>> hits.nonzero()
(array([3]),)

另一方面,这可能会更慢。没有测试,不清楚多慢?请参阅Jamie的答案,了解另一个必须检查误报的内存保留选项。我想这两种解决方案之间的速度差异在很大程度上取决于输入的性质。

答案 2 :(得分:16)

基于卷积的方法,应该比基于stride_tricks的方法更具内存效率:

def find_subsequence(seq, subseq):
    target = np.dot(subseq, subseq)
    candidates = np.where(np.correlate(seq,
                                       subseq, mode='valid') == target)[0]
    # some of the candidates entries may be false positives, double check
    check = candidates[:, np.newaxis] + np.arange(len(subseq))
    mask = np.all((np.take(seq, check) == subseq), axis=-1)
    return candidates[mask]

对于非常大的数组,可能无法使用stride_tricks方法,但这个方法仍有效:

haystack = np.random.randint(1000, size=(1e6))
needle = np.random.randint(1000, size=(100,))
# Hide 10 needles in the haystack
place = np.random.randint(1e6 - 100 + 1, size=10)
for idx in place:
    haystack[idx:idx+100] = needle

In [3]: find_subsequence(haystack, needle)
Out[3]: 
array([253824, 321497, 414169, 456777, 635055, 879149, 884282, 954848,
       961100, 973481], dtype=int64)

In [4]: np.all(np.sort(place) == find_subsequence(haystack, needle))
Out[4]: True

In [5]: %timeit find_subsequence(haystack, needle)
10 loops, best of 3: 79.2 ms per loop

答案 3 :(得分:7)

你可以调用tostring()方法将数组转换为字符串,然后你可以使用快速字符串搜索。当你有许多子阵列需要检查时,这种方法可能会更快。

import numpy as np

a = np.array([1,2,3,4,5,6])
b = np.array([2,3,4])
print a.tostring().index(b.tostring())//a.itemsize

答案 4 :(得分:7)

已编辑,包括更深入的讨论,更好的代码和更多基准)


摘要

对于原始速度和效率,可以使用经典算法之一的Cython或Numba加速版本(当输入分别是Python序列或NumPy数组时)。

推荐的方法是:

  • find_kmp_cy()用于Python序列(listtuple等)
  • find_kmp_nb()用于NumPy数组

find_rk_cy()find_rk_nb()是其他有效的方法,它们具有更高的内存效率,但不能保证在线性时间内运行。

如果无法使用Cython / Numba,则find_kmp()find_rk()都是大多数用例的一个很好的全方位解决方案,尽管对于一般情况和Python序列而言,幼稚的方法,以某种形式,尤其是find_pivot(),可能会更快。对于NumPy数组,find_conv()(来自@Jaime answer)优于任何非加速的简单方法。

(下面是完整代码,分别是herethere。)


理论

这是计算机科学中的经典问题,其名称为字符串搜索或字符串匹配问题。 天真的方法基于两个嵌套循环,平均计算复杂度为O(n + m),但最差的情况是O(n m)。 多年来,已经开发出许多alternative approaches来保证更好的最坏情况下的性能。

在经典算法中,最适合通用序列(因为它们不依赖字母)的算法是:

后一种算法的效率依赖于rolling hash的计算,因此可能需要一些额外的输入知识才能获得最佳性能。 最终,它最适合于均质数据,例如数字数组。 当然,Python中一个著名的数字数组示例是NumPy数组。

备注

  • 天真的算法非常简单,可以在Python中以不同程度的运行速度适用于不同的实现。
  • 其他算法在可以通过语言技巧进行优化的方面灵活性较差。
  • Python中的显式循环可能是一个速度瓶颈,可以使用一些技巧在解释器之外执行循环。
  • Cython特别擅长加速通用Python代码的显式循环。
  • Numba特别擅长加速NumPy数组上的显式循环。
  • 这是一个很好的生成器用例,因此所有代码都将使用那些而不是常规函数。

Python序列(listtuple等)

基于朴素算法

  • find_loop()find_loop_cy()find_loop_nb()分别是纯Python,Cython和Numba JITing中仅显式循环的实现。请注意Numba版本中的forceobj=True,这是必需的,因为我们正在使用Python对象输入。
def find_loop(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        found = True
        for j in range(m):
            if seq[i + j] != subseq[j]:
                found = False
                break
        if found:
            yield i
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True


def find_loop_cy(seq, subseq):
    cdef Py_ssize_t n = len(seq)
    cdef Py_ssize_t m = len(subseq)
    for i in range(n - m + 1):
        found = True
        for j in range(m):
            if seq[i + j] != subseq[j]:
                found = False
                break
        if found:
            yield i
find_loop_nb = nb.jit(find_loop, forceobj=True)
find_loop_nb.__name__ = 'find_loop_nb'
  • find_all()在综合生成器上将内部循环替换为all()
def find_all(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if all(seq[i + j] == subseq[j] for j in range(m)):
            yield i
  • find_slice()在切片==后用直接比较[]替换了内部循环
def find_slice(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if seq[i:i + m] == subseq:
            yield i
    在切片find_mix()之后,
  • find_mix2()==用直接比较[]替换了内部循环,但是在第一个(和最后一个)字符上包括一个或两个附加短路这可能会更快,因为使用int进行切片比使用slice()进行切片要快得多。
def find_mix(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if seq[i] == subseq[0] and seq[i:i + m] == subseq:
            yield i
def find_mix2(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if seq[i] == subseq[0] and seq[i + m - 1] == subseq[m - 1] \
                and seq[i:i + m] == subseq:
            yield i
  • find_pivot()find_pivot2()使用子序列的第一项,用多个.index()调用替换外部循环,同时对内部循环使用切片,最终导致额外的短路在最后一个项目上(第一个按结构匹配)。多个.index()调用被包装在index_all()生成器中(可能单独使用)。
def index_all(seq, item, start=0, stop=-1):
    try:
        n = len(seq)
        if n > 0:
            start %= n
            stop %= n
            i = start
            while True:
                i = seq.index(item, i)
                if i <= stop:
                    yield i
                    i += 1
                else:
                    return
        else:
            return
    except ValueError:
        pass


def find_pivot(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if m > n:
        return
    for i in index_all(seq, subseq[0], 0, n - m):
        if seq[i:i + m] == subseq:
            yield i
def find_pivot2(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if m > n:
        return
    for i in index_all(seq, subseq[0], 0, n - m):
        if seq[i + m - 1] == subseq[m - 1] and seq[i:i + m] == subseq:
            yield i

基于Knuth–Morris-Pratt(KMP)算法

  • find_kmp()是该算法的纯Python实现。由于没有简单的循环,也没有可以使用slice()进行切片的地方,因此除了使用Cython之外,没有太多要做优化的工作(Numba再次需要forceobj=True会导致代码缓慢)。
def find_kmp(seq, subseq):
    n = len(seq)
    m = len(subseq)
    # : compute offsets
    offsets = [0] * m
    j = 1
    k = 0
    while j < m: 
        if subseq[j] == subseq[k]: 
            k += 1
            offsets[j] = k
            j += 1
        else: 
            if k != 0: 
                k = offsets[k - 1] 
            else: 
                offsets[j] = 0
                j += 1
    # : find matches
    i = j = 0
    while i < n: 
        if seq[i] == subseq[j]: 
            i += 1
            j += 1
        if j == m:
            yield i - j
            j = offsets[j - 1] 
        elif i < n and seq[i] != subseq[j]: 
            if j != 0: 
                j = offsets[j - 1] 
            else: 
                i += 1
  • find_kmp_cy()是该算法的Cython实现,其中索引使用C int数据类型,这导致代码快得多。
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True


def find_kmp_cy(seq, subseq):
    cdef Py_ssize_t n = len(seq)
    cdef Py_ssize_t m = len(subseq)
    # : compute offsets
    offsets = [0] * m
    cdef Py_ssize_t j = 1
    cdef Py_ssize_t k = 0
    while j < m: 
        if subseq[j] == subseq[k]: 
            k += 1
            offsets[j] = k
            j += 1
        else: 
            if k != 0: 
                k = offsets[k - 1] 
            else: 
                offsets[j] = 0
                j += 1
    # : find matches
    cdef Py_ssize_t i = 0
    j = 0
    while i < n: 
        if seq[i] == subseq[j]: 
            i += 1
            j += 1
        if j == m:
            yield i - j
            j = offsets[j - 1] 
        elif i < n and seq[i] != subseq[j]: 
            if j != 0: 
                j = offsets[j - 1] 
            else: 
                i += 1

基于Rabin-Karp(RK)算法

  • find_rk()是一个纯Python实现,它依赖于Python的hash()进行哈希的计算(和比较)。通过简单的sum()使这种哈希值滚动。然后,通过减去前一次访问的项目{{1}的hash()的结果并将新考虑的项目{{1 }}。
seq[i - 1]
  • hash()是该算法的Cython实现,其中索引使用适当的C数据类型,这将导致更快的代码。请注意,seq[i + m - 1]会截断“基于主机位宽的返回值。”
def find_rk(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if seq[:m] == subseq:
        yield 0
    hash_subseq = sum(hash(x) for x in subseq)  # compute hash
    curr_hash = sum(hash(x) for x in seq[:m])  # compute hash
    for i in range(1, n - m + 1):
        curr_hash += hash(seq[i + m - 1]) - hash(seq[i - 1])   # update hash
        if hash_subseq == curr_hash and seq[i:i + m] == subseq:
            yield i

基准

以上功能在两个输入上求值:

  • 随机输入
find_rk_cy()
  • 天真的算法的(几乎)最差输入
hash()

%%cython -c-O3 -c-march=native -a #cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True def find_rk_cy(seq, subseq): cdef Py_ssize_t n = len(seq) cdef Py_ssize_t m = len(subseq) if seq[:m] == subseq: yield 0 cdef Py_ssize_t hash_subseq = sum(hash(x) for x in subseq) # compute hash cdef Py_ssize_t curr_hash = sum(hash(x) for x in seq[:m]) # compute hash cdef Py_ssize_t old_item, new_item for i in range(1, n - m + 1): old_item = hash(seq[i - 1]) new_item = hash(seq[i + m - 1]) curr_hash += new_item - old_item # update hash if hash_subseq == curr_hash and seq[i:i + m] == subseq: yield i 具有固定大小(def gen_input(n, k=2): return tuple(random.randint(0, k - 1) for _ in range(n)) )。 由于有很多选择,因此已经完成了两个单独的分组,并且省略了一些变化很小且时序几乎相同的解决方案(即def gen_input_worst(n, k=-2): result = [0] * n result[k] = 1 return tuple(result) subseq)。 对于每组,两个输入都经过测试。 对于每个基准,都提供了完整的图和最快的方法。

天真随机

bm_full_naive_random bm_zoom_naive_random

最糟糕的天真

bm_full_naive_worst bm_zoom_naive_worst

其他随机

bm_full_other_random bm_zoom_other_random

其他情况最糟糕

bm_full_other_worst bm_zoom_other_worst

(完整的代码可用here。)


NumPy数组

基于朴素算法

  • 32find_mix2()find_pivot2()分别是纯Python,Cython和Numba JITing中仅显式循环的实现。前两个代码与上面相同,因此省略。 find_loop()现在享受快速的JIT编译。内部循环已写在单独的函数中,因为它随后可用于find_loop_cy()(在Numba函数内部调用Numba函数不会招致Python典型的函数调用代价)。
find_loop_nb()
  • find_loop_nb()与上面的相同,而find_rk_nb()@nb.jit def _is_equal_nb(seq, subseq, m, i): for j in range(m): if seq[i + j] != subseq[j]: return False return True @nb.jit def find_loop_nb(seq, subseq): n = len(seq) m = len(subseq) for i in range(n - m + 1): if _is_equal_nb(seq, subseq, m, i): yield i find_all()与上面的几乎相同,唯一的不同是{{1 }}现在是find_slice()的参数:find_mix()

  • find_mix2()seq[i:i + m] == subseq具有与上述相同的想法,除了现在使用np.all()代替np.all(seq[i:i + m] == subseq),并且需要将数组相等性包含在find_pivot()通话。

find_pivot2()
  • np.where()通过滚动窗口表示循环,并使用index_all()检查匹配。这使所有循环向量化,但以创建大型临时对象为代价,同时仍在很大程度上应用朴素算法。 (方法来自@senderle answer)。
np.all()
  • def find_pivot(seq, subseq): n = len(seq) m = len(subseq) if m > n: return max_i = n - m for i in np.where(seq == subseq[0])[0]: if i > max_i: return elif np.all(seq[i:i + m] == subseq): yield i def find_pivot2(seq, subseq): n = len(seq) m = len(subseq) if m > n: return max_i = n - m for i in np.where(seq == subseq[0])[0]: if i > max_i: return elif seq[i + m - 1] == subseq[m - 1] \ and np.all(seq[i:i + m] == subseq): yield i 是上述内容的一种内存效率更高的变体,其中向量化只是部分的,并且保留了一个显式循环(沿着预期的最短维度-find_rolling()的长度)。 (该方法也来自@senderle answer)。
np.all()

基于Knuth–Morris-Pratt(KMP)算法

  • def rolling_window(arr, size): shape = arr.shape[:-1] + (arr.shape[-1] - size + 1, size) strides = arr.strides + (arr.strides[-1],) return np.lib.stride_tricks.as_strided(arr, shape=shape, strides=strides) def find_rolling(seq, subseq): bool_indices = np.all(rolling_window(seq, len(subseq)) == subseq, axis=1) yield from np.mgrid[0:len(bool_indices)][bool_indices] 与上述相同,而find_rolling2()是JIT的直接编译。
subseq

基于Rabin-Karp(RK)算法

  • def find_rolling2(seq, subseq): windows = rolling_window(seq, len(subseq)) hits = np.ones((len(seq) - len(subseq) + 1,), dtype=bool) for i, x in enumerate(subseq): hits &= np.in1d(windows[:, i], [x]) yield from hits.nonzero()[0] 与上述相同,只是在find_kmp()调用中再次包含了find_kmp_nb()

  • find_kmp_nb = nb.jit(find_kmp) find_kmp_nb.__name__ = 'find_kmp_nb' 是上述的Numba加速版本。使用先前定义的find_rk()来确定匹配项,而对于哈希,它使用Numba加速的seq[i:i + m] == subseq函数,其定义非常简单。

np.all()
  • find_rk_nb()使用伪Rabin-Karp方法,其中使用_is_equal_nb()乘积对初始候选对象进行哈希处理,并位于sum_hash_nb()@nb.jit def sum_hash_nb(arr): result = 0 for x in arr: result += hash(x) return result @nb.jit def find_rk_nb(seq, subseq): n = len(seq) m = len(subseq) if _is_equal_nb(seq, subseq, m, 0): yield 0 hash_subseq = sum_hash_nb(subseq) # compute hash curr_hash = sum_hash_nb(seq[:m]) # compute hash for i in range(1, n - m + 1): curr_hash += hash(seq[i + m - 1]) - hash(seq[i - 1]) # update hash if hash_subseq == curr_hash and _is_equal_nb(seq, subseq, m, i): yield i 与{{1 }}。这种方法是伪的,因为尽管它仍然使用散列来识别可能的候选者,但是它可能不被认为是滚动散列(它取决于find_conv()的实际实现。此外,它还需要创建一个临时数组输入的大小(方法来自@Jaime answer)。
np.dot()

基准

像以前一样,以上功能在两个输入上求值:

  • 随机输入
seq
  • 天真的算法的(几乎)最差输入
subseq

np.where()具有固定大小(np.correlate())。 此图遵循以前的方案,为方便起见,在下面进行了总结。

由于有很多选择,因此已经完成了两个单独的分组,并且省略了一些变化很小且时序几乎相同的解决方案(即def find_conv(seq, subseq): target = np.dot(subseq, subseq) candidates = np.where(np.correlate(seq, subseq, mode='valid') == target)[0] check = candidates[:, np.newaxis] + np.arange(len(subseq)) mask = np.all((np.take(seq, check) == subseq), axis=-1) yield from candidates[mask] def gen_input(n, k=2): return np.random.randint(0, k, n) )。 对于每组,两个输入都经过测试。 对于每个基准,都提供了完整的图和最快的方法。

天真随机

bm_full_naive_random bm_zoom_naive_random

最糟糕的天真

bm_full_naive_worst bm_zoom_naive_worst

其他随机

bm_full_other_random bm_zoom_other_random

其他情况最糟糕

bm_full_other_worst bm_zoom_other_worst

(完整的代码可用here。)

答案 5 :(得分:2)

另一次尝试,但我确信还有更多的pythonic&amp;有效的方式......

def array_match(a, b):
    for i in xrange(0, len(a)-len(b)+1):
        if a[i:i+len(b)] == b:
            return i
    return None
a = [1, 2, 3, 4, 5, 6]
b = [2, 3, 4]

print array_match(a,b)
1

(第一个答案不在问题的范围内,正如cdhowie所提到的那样)

set(a) & set(b) == set(b)

答案 6 :(得分:2)

我知道这是一个相当古老的问题,但我最近必须以快速有效的方式解决这个问题,我发现最快的方法(特别是对于长阵列),我想我把它留在这里作为参考:

data = np.array([1, 2, 3, 4, 5, 6])
sequence = np.array([3, 4, 5])
data.tostring().index(sequence.tostring())//data.itemize

你必须要小心,数组和序列都有相同的dtype。

答案 7 :(得分:1)

这是一个相当直接的选择:

def first_subarray(full_array, sub_array):
    n = len(full_array)
    k = len(sub_array)
    matches = np.argwhere([np.all(full_array[start_ix:start_ix+k] == sub_array) 
                   for start_ix in range(0, n-k+1)])
    return matches[0]

然后使用原始的a,b向量得到:

a = [1, 2, 3, 4, 5, 6]
b = [2, 3, 4]
first_subarray(a, b)
Out[44]: 
array([1], dtype=int64)

答案 8 :(得分:0)

像这样创建一个数组(或转换)

enum

答案 9 :(得分:0)

快速比较三种建议的解决方案(随机创建矢量的平均迭代时间为100):

import time
import collections
import numpy as np


def function_1(seq, sub):
    # direct comparison
    seq = list(seq)
    sub = list(sub)
    return [i for i in range(len(seq) - len(sub)) if seq[i:i+len(sub)] == sub]

def function_2(seq, sub):
    # Jamie's solution
    target = np.dot(sub, sub)
    candidates = np.where(np.correlate(seq, sub, mode='valid') == target)[0]
    check = candidates[:, np.newaxis] + np.arange(len(sub))
    mask = np.all((np.take(seq, check) == sub), axis=-1)
    return candidates[mask]

def function_3(seq, sub):
    # HYRY solution
    return seq.tostring().index(sub.tostring())//seq.itemsize


# --- assessment time performance
N = 100

seq = np.random.choice([0, 1, 2, 3, 4, 5, 6], 3000)
sub = np.array([1, 2, 3])

tim = collections.OrderedDict()
tim.update({function_1: 0.})
tim.update({function_2: 0.})
tim.update({function_3: 0.})

for function in tim.keys():
    for _ in range(N):
        seq = np.random.choice([0, 1, 2, 3, 4], 3000)
        sub = np.array([1, 2, 3])
        start = time.time()
        function(seq, sub)
        end = time.time()
        tim[function] += end - start

timer_dict = collections.OrderedDict()
for key, val in tim.items():
    timer_dict.update({key.__name__: val / N})

print(timer_dict)

(在我的旧机器上)会导致以下结果:

OrderedDict([
('function_1', 0.0008518099784851074), 
('function_2', 8.157730102539063e-05), 
('function_3', 6.124973297119141e-06)
])

答案 10 :(得分:0)

首先,将列表转换为字符串。

a = ''.join(str(i) for i in a)
b = ''.join(str(i) for i in b)

转换为字符串后,可以使用以下字符串函数轻松找到子字符串的索引。

a.index(b)

干杯!