检查numpy数组是否为二进制的快速方法(仅包含0和1)

时间:2016-11-14 19:03:42

标签: python numpy

给定一个numpy数组,如果它只包含0和1,怎么能弄清楚? 有没有实施的方法?

7 个答案:

答案 0 :(得分:9)

很少有方法 -

((a==0) | (a==1)).all()
~((a!=0) & (a!=1)).any()
np.count_nonzero((a!=0) & (a!=1))==0
a.size == np.count_nonzero((a==0) | (a==1))

运行时测试 -

In [313]: a = np.random.randint(0,2,(3000,3000)) # Only 0s and 1s

In [314]: %timeit ((a==0) | (a==1)).all()
     ...: %timeit ~((a!=0) & (a!=1)).any()
     ...: %timeit np.count_nonzero((a!=0) & (a!=1))==0
     ...: %timeit a.size == np.count_nonzero((a==0) | (a==1))
     ...: 
10 loops, best of 3: 28.8 ms per loop
10 loops, best of 3: 29.3 ms per loop
10 loops, best of 3: 28.9 ms per loop
10 loops, best of 3: 28.8 ms per loop

In [315]: a = np.random.randint(0,3,(3000,3000)) # Contains 2 as well

In [316]: %timeit ((a==0) | (a==1)).all()
     ...: %timeit ~((a!=0) & (a!=1)).any()
     ...: %timeit np.count_nonzero((a!=0) & (a!=1))==0
     ...: %timeit a.size == np.count_nonzero((a==0) | (a==1))
     ...: 
10 loops, best of 3: 28 ms per loop
10 loops, best of 3: 27.5 ms per loop
10 loops, best of 3: 29.1 ms per loop
10 loops, best of 3: 28.9 ms per loop

他们的运行时间似乎具有可比性。

答案 1 :(得分:4)

看起来你可以通过以下方式实现它:

np.array_equal(a, a.astype(bool))

如果你的数组很大,它应该避免复制太多的数组(如在其他一些答案中)。因此,它应该比其他答案略快(但未经过测试)。

答案 2 :(得分:2)

如果您可以访问Numba(或者cython),您可以编写类似下面的内容,这对于捕获非二进制数组来说会明显更快,因为它会立即使计算/停止短路而不是继续全部要素:

import numpy as np
import numba as nb

@nb.njit
def check_binary(x):
    is_binary = True
    for v in np.nditer(x):
        if v.item() != 0 and v.item() != 1:
            is_binary = False
            break

    return is_binary

在没有像Numba或Cython这样的加速器的帮助下在纯python中运行它会使这种方法过于缓慢。

时序:

a = np.random.randint(0,2,(3000,3000)) # Only 0s and 1s

%timeit ((a==0) | (a==1)).all()
# 100 loops, best of 3: 15.1 ms per loop

%timeit check_binary(a)
# 100 loops, best of 3: 11.6 ms per loop

a = np.random.randint(0,3,(3000,3000)) # Contains 2 as well

%timeit ((a==0) | (a==1)).all()
# 100 loops, best of 3: 14.9 ms per loop

%timeit check_binary(a)
# 1000000 loops, best of 3: 543 ns per loop

答案 3 :(得分:1)

只有一个数据循环:

0 <= np.bitwise_or.reduce(ar) <= 1

请注意,这不适用于浮点dtype。

如果值保证为非负值,则可能会出现短路行为:

try:
    np.empty((2,), bool)[ar]
    is_binary = True
except IndexError:
    is_binary = False

此方法(总是)分配与参数形状相同的临时数组,并且似乎比第一种方法更慢地循环数据。

答案 4 :(得分:1)

numpy唯一性怎么样?

np.unique(arr)

如果为二进制,则应返回[0,1]。

答案 5 :(得分:1)

以下内容适用于仅包含数字的所有数组。

set(array).issubset({0,1}) 

答案 6 :(得分:1)

我们可以使用 np.isin()

input_array = input_array.squeeze(-1)
is_binary   = np.isin(input_array, [0,1]).all()

第一行:
squeeze 展开输入数组,因为我们不想处理 np.isin() 与多维数组的复杂性。

第二行:
np.isin() 检查输入的所有元素是否属于 0 或 1。
np.isin() 返回 [True, False, True, True..] 的列表。
然后 all() 确保列表包含所有 True。