如何判断numpy布尔数组是否只包含一个“True`s”块?

时间:2013-09-13 19:25:26

标签: python numpy

如果我有一个包含布尔值的numpy数组,比如一些数学比较的输出,那么确定该数组是否只包含一个True个连续块的最佳方法是什么,例如

array([False, False, False, True, True, True, False, False, False], dtype=bool)

即。序列...,True, False, ..., True...永远不会发生的地方?

4 个答案:

答案 0 :(得分:5)

在这种情况下,

numpy.diff很有用。您可以计算diff ed数组中-1的数量。

注意,您还需要检查最后一个元素 - 如果它是True,则diff ed数组中不会有-1表示。更好的是,您可以在False之前将diff附加到数组。

import numpy as np
a = np.array([False, False, False, True, True, True, False, False, False], dtype=bool)
d = np.diff(np.asarray(a, dtype=int))
d
=> array([ 0,  0,  1,  0,  0, -1,  0,  0])
(d < 0).sum()
=> 1

最后添加False

b = np.append(a, [ False ])
d = np.diff(np.asarray(b, dtype=int))
...

现在,“序列......,真,假,......,真......永远不会发生”iff (d<0).sum() < 2

避免append操作(并使代码更加模糊)的一个技巧是:(d<0).sum() + a[-1] < 2(即,如果[-1]为True,则将其计为块)。当然,只有当a不为空时才会起作用。

答案 1 :(得分:3)

不是numpy原生方法,但您可以使用itertools.groupby将连续值块减少为单个项目,然后使用any检查只显示一个真值。由于grouped是可迭代的,所以只要找到真值,第一个any就会返回True,然后您可以继续检查迭代的剩余部分,并确保没有其他真实的值。

from itertools import groupby

def has_single_true_block(sequence):
    grouped = (k for k, g in groupby(sequence))
    has_true = any(grouped)
    has_another_true = any(grouped)
    return has_true and not has_another_true

答案 2 :(得分:1)

如果你只有一个Trues块,那意味着你在数组中有一个转换,或者你有两个转换,数组以False开头和结尾。还有一个简单的例子,整个数组都是True。所以你可以这样做:

def singleBlockTrue(array):
   if len(array) == 0:
       return False
   transitions = (array[1:] != array[:-1]).sum()
   if transitions == 0:
       return array[0]
   if transitions == 1:
       return True
   if transitions == 2:
       return not array[0]
   return False

这实际上是相同的逻辑,但代码更清晰。

def singleBlockTrue(array):
    if len(array) == 0:
        return False
    transitions = (array[1:] != array[:-1]).sum()
    transitions = transitions + array[0] + array[-1]
    return transitions == 2

与评论相关的一些时间:

In [41]: a = np.zeros(1000000, dtype=bool)

In [42]: timeit a[:-1] != a[1:]
100 loops, best of 3: 2.93 ms per loop

In [43]: timeit np.diff(a.view('uint8'))
100 loops, best of 3: 2.45 ms per loop

In [44]: timeit np.diff(a.astype('uint8'))
100 loops, best of 3: 3.41 ms per loop

In [45]: timeit np.diff(np.array(a, 'uint8'))
100 loops, best of 3: 3.42 ms per loop

答案 3 :(得分:0)

import numpy as np

def has_single_true_block(arr):
    if not len(arr):
        return False
    blocks = len(np.array_split(arr, np.where(np.diff(arr) != 0)[0] + 1))
    if blocks > 3:
        return False
    elif blocks == 3 and arr[0] and arr[-1]:
        return False
    elif blocks == 1 and not arr[0]:  # 0 True blocks
        return False
    return True

# TESTS

a1 = np.array([False, False, True, True, True, False, False], dtype=bool)
has_single_true_block(a1)  # => True

a2 = np.array([True, True, False, False], dtype=bool)
has_single_true_block(a2)  # => True

a3 = np.array([False, False, True, True], dtype=bool)
has_single_true_block(a3)  # => True

f1 = np.array([False, False, True, False, True, False, False], dtype=bool)
has_single_true_block(f1)  # => False

f2 = np.array([True, True, False, False, True, True], dtype=bool)
has_single_true_block(f2)  # => False

f3 = np.array([False, False, False], dtype=bool)
has_single_true_block(f3)  # => False