使用矢量化查找连续元素的最大数量

时间:2014-04-29 20:56:55

标签: python numpy vectorization

作为我项目的一部分,我需要找到一个向量中是否有4个或更多个连续元素以及它们的索引。目前我正在使用以下代码:

#sample arrays:
#a1 = np.array([0, 1, 2, 3, 5])
#a2 = np.array([0, 1, 3, 4, 5, 6])
#a3 = np.array([0, 1, 3, 4, 5])
a4 = array([0, 1, 2, 4, 5, 6])

dd = np.diff(a4) #array([1, 1, 2, 1, 1])
c = 0
idx = []
for i in range(len(dd)):
    if dd[i]==1 and c<3:
        idx.append(i)
        c+=1
    elif dd[i]!=1 and c>=3:
        break
    else:
         c=0
         idx=[]

我很想知道是否可以避免for循环并且只使用numpy函数来执行此任务。

3 个答案:

答案 0 :(得分:5)

这将为您提供一个包含所有连续块长度的数组:

np.diff(np.concatenate(([-1],) + np.nonzero(np.diff(a) != 1) + ([len(a)-1],)))

一些测试:

>>> a = [1, 2, 3, 4, 5, 6, 9, 10, 11, 14, 17, 18, 19, 20, 21]
>>> np.diff(np.concatenate(([-1],) + np.nonzero(np.diff(a) != 1) +
                           ([len(a)-1],)))
array([6, 3, 1, 5], dtype=int64)

>>> a = [0, 1, 2, 4, 5, 6]
>>> np.diff(np.concatenate(([-1],) + np.nonzero(np.diff(a) != 1) +
                           ([len(a)-1],)))
array([3, 3], dtype=int64)

要检查是否有至少4个项目,请在np.any(... >= 4)中包含以上代码。


要了解这是如何工作的,我们可以从内到外计算出我的第一个例子:

>>> a = [1, 2, 3, 4, 5, 6, 9, 10, 11, 14, 17, 18, 19, 20, 21]

首先,我们计算出连续项之间的增量:

>>> np.diff(a)
array([1, 1, 1, 1, 1, 3, 1, 1, 3, 3, 1, 1, 1, 1])

然后,我们确定delta不是1的位置,即一大块连续项目开始或结束的位置:

>>> np.diff(a) != 1
array([False, False, False, False, False,  True, False, False,  True,
        True, False, False, False, False], dtype=bool)

我们提取True s的位置:

>>> np.nonzero(np.diff(a) != 1)
(array([5, 8, 9], dtype=int64),)

上述指数标志着连续连胜的最后一项。 Python切片定义为startlast+1,因此我们可以将该数组增加1,在开头添加零,在结尾添加数组的长度,并使所有开始和结束的索引连续序列,即:

>>> np.concatenate(([0], np.nonzero(np.diff(a) != 1)[0] + 1, [len(a)]))
array([ 0,  6,  9, 10, 15], dtype=int64)

从连续索引中获取差异将为我们提供每个连续块的所需长度。因为我们所关心的只是差异,而不是在索引中添加一个,在我的原始答案中,我选择先添加-1并附加len(a)-1

>>> np.concatenate(([-1],) + np.nonzero(np.diff(a) != 1) + ([len(a)-1],))
array([-1,  5,  8,  9, 14], dtype=int64)
>>> np.diff(np.concatenate(([-1],) + np.nonzero(np.diff(a) != 1) +
                           ([len(a)-1],)))
array([6, 3, 1, 5], dtype=int64)

假设在这个数组中你确定你想要5项块的索引,这个数组位于3的位置。要恢复该块的开始和停止索引,您只需执行以下操作:

>>> np.concatenate(([0], np.nonzero(np.diff(a) != 1)[0] + 1, [len(a)]))[3:3+2]
array([10, 15], dtype=int64)
>>> a[10:15]
[17, 18, 19, 20, 21]

答案 1 :(得分:2)

是。你正确地开始了:

from numpy import array, diff, where

numbers = array([0, 1, 3, 4, 5, 5])
differences = diff(numbers)

您对连续数字感兴趣:

consecutives = differences == 1

你想要两个连续的情况。您可以将数组与其偏移量进行比较:

(consecutives[1:] & consecutives[:-1]).any()
#>>> True

要获取出现次数,请使用.sum()代替.any()

如果您想要索引,只需使用numpy.where

[offset_indexes] = where(consecutives[1:] & consecutives[:-1])
offset_indexes
#>>> array([2])

编辑:您似乎已将所需长度从3编辑为4。这会使我的代码无效,但您只需要设置

consecutives[1:] & consecutives[:-1]

consecutives[2:] & consecutives[1:-1] & consecutives[:-2]

这是一个毫无意义的通用版本:

from numpy import arange, array, diff, where

def doubling_step_shifts(shifts):
    """
    When you apply a mask of some kind of all rotations,
    often the size of the last prints will allow shifts
    larger than 1. This is a helper for that.

    A mask is assumed to exist before invocation, as
    this is typically called repeatedly on the mask or
    a copy.
    """
    # Total shift
    subtotal = 1
    step = 1

    # While the shifts won't overflow
    while subtotal + step < shifts:
        yield step
        subtotal += step
        step *= 2

    # Make up the remainder
    if shifts - subtotal > 0:
        yield shifts - subtotal

def consecutive_indexes_of_length(numbers, length):
    # Constructing "consecutives" creates a
    # minimum mask of 1, whereas this would need
    # a mask of 0, so we special-case these
    if length <= 1:
        return arange(numbers.size)

    # Mask of consecutive numbers
    consecutives = diff(numbers) == 1
    consecutives.resize(numbers.size)

    # Recursively reapply mask to cover lengths too short
    for i in doubling_step_shifts(length-1):
        consecutives[:-i] &= consecutives[i:]

    # Reextend those lengths
    for i in doubling_step_shifts(length):
        consecutives[i:] = consecutives[i:] | consecutives[:-i]

    # Give the indexes
    return where(consecutives)[0]

编辑:加快了速度(numpy.roll很慢)。

还有一些测试:

numbers = array([1, 2, 3, 4, 5, 6, 9, 10, 11, 14, 17, 18, 19, 20, 21])

consecutive_indexes_of_length(numbers, 1)
#>>> array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14])
consecutive_indexes_of_length(numbers, 2)
#>>> array([ 0,  1,  2,  3,  4,  5,  6,  7,  8, 10, 11, 12, 13, 14])
consecutive_indexes_of_length(numbers, 3)
#>>> array([ 0,  1,  2,  3,  4,  5,  6,  7,  8, 10, 11, 12, 13, 14])
consecutive_indexes_of_length(numbers, 4)
#>>> array([ 0,  1,  2,  3,  4,  5, 10, 11, 12, 13, 14])
consecutive_indexes_of_length(numbers, 5)
#>>> array([ 0,  1,  2,  3,  4,  5, 10, 11, 12, 13, 14])
consecutive_indexes_of_length(numbers, 6)
#>>> array([0, 1, 2, 3, 4, 5])
consecutive_indexes_of_length(numbers, 7)
#>>> array([], dtype=int64)

为什么呢?没理由。 O(n log k)其中n是列表中元素的数量,kGROUPSIZE ,所以不要大量使用它大GROUPSIZE。但是,对于几乎所有的组大小来说都应该相当快。

编辑:现在它很快。我打赌Cython会更快,但这样很好。

该实现具有相对简单,可扩展和使用非常原始操作的优点。除了非常小的输入外,这可能不会比Cython循环快。

答案 2 :(得分:2)

以下递归解决方案怎么样? (当然只适用于1dim数组)

我发现它非常优雅。我并不是说你应该使用它,但我很开心。

import numpy as np

def is_consecutive(arr, n):
    if n <= len(arr) <= 1:
        return True
    if len(arr) < n:
        return False
    diffs1idx = np.where(np.diff(arr) == 1)[0]
    return is_consecutive(diffs1idx, n-1)

print is_consecutive([1,2], 3)  # False
print is_consecutive([1,2,3], 3)  # True
print is_consecutive([5,1,2,3], 3)  # True
print is_consecutive([4,9,1,5,7], 3)  # False
print is_consecutive([4,9,1,2,3, 7, 9], 3)  # True
print is_consecutive(np.arange(100), 100)  # True
print is_consecutive(np.append([666], np.arange(100)), 100)  # True
print is_consecutive(np.append([666], np.arange(100)), 101)  # False

(请不要问我它是如何工作的......我不明白递归...)