我们用零填充numpy数组:
np.zeros((N,N+1))
但是我们如何检查给定n * n numpy数组矩阵中的所有元素是否为零 如果所有值都确实为零,则该方法只需返回True。
答案 0 :(得分:128)
此处发布的其他答案可行,但使用的最明确,最有效的功能是numpy.any()
:
>>> all_zeros = not np.any(a)
或
>>> all_zeros = not a.any()
numpy.all(a==0)
优先,因为它使用较少的RAM。 (它不需要a==0
术语创建的临时数组。)numpy.count_nonzero(a)
更快,因为它可以在找到第一个非零元素时立即返回。np.any()
不再使用"短路"逻辑,所以你不会看到小阵列的速度优势。答案 1 :(得分:55)
>>> np.count_nonzero(np.eye(4))
4
>>> np.count_nonzero([[0,1,7,0,0],[3,0,0,2,19]])
5
答案 2 :(得分:41)
我在这里使用np.all,如果你有一个数组a:
>>> np.all(a==0)
答案 3 :(得分:1)
正如另一个答案所说,如果您知道0
是数组中唯一可能存在的虚假元素,则可以利用真实/虚假评估。数组中的所有元素都是虚假的,前提是其中没有任何真实的元素。*
>>> a = np.zeros(10)
>>> not np.any(a)
True
但是,答案声称any
比其他选项要快,部分原因是短路。截至2018年,Numpy的all
和any
没有短路。
如果您经常这样做,那么使用numba
制作自己的短路版本非常容易:
import numba as nb
# short-circuiting replacement for np.any()
@nb.jit(nopython=True)
def sc_any(array):
for x in array.flat:
if x:
return True
return False
# short-circuiting replacement for np.all()
@nb.jit(nopython=True)
def sc_all(array):
for x in array.flat:
if not x:
return False
return True
即使没有短路,它们也往往比Numpy的版本更快。 count_nonzero
是最慢的。
一些输入来检查性能:
import numpy as np
n = 10**8
middle = n//2
all_0 = np.zeros(n, dtype=int)
all_1 = np.ones(n, dtype=int)
mid_0 = np.ones(n, dtype=int)
mid_1 = np.zeros(n, dtype=int)
np.put(mid_0, middle, 0)
np.put(mid_1, middle, 1)
# mid_0 = [1 1 1 ... 1 0 1 ... 1 1 1]
# mid_1 = [0 0 0 ... 0 1 0 ... 0 0 0]
检查:
## count_nonzero
%timeit np.count_nonzero(all_0)
# 220 ms ± 8.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.count_nonzero(all_1)
# 150 ms ± 4.56 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
### all
# np.all
%timeit np.all(all_1)
%timeit np.all(mid_0)
%timeit np.all(all_0)
# 56.8 ms ± 3.41 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 57.4 ms ± 1.76 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 55.9 ms ± 2.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# sc_all
%timeit sc_all(all_1)
%timeit sc_all(mid_0)
%timeit sc_all(all_0)
# 44.4 ms ± 2.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 22.7 ms ± 599 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 288 ns ± 6.36 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
### any
# np.any
%timeit np.any(all_0)
%timeit np.any(mid_1)
%timeit np.any(all_1)
# 60.7 ms ± 1.38 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 60 ms ± 287 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 57.7 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# sc_any
%timeit sc_any(all_0)
%timeit sc_any(mid_1)
%timeit sc_any(all_1)
# 41.7 ms ± 1.24 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 22.4 ms ± 1.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 287 ns ± 12.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
*有用的all
和any
等效项:
np.all(a) == np.logical_not(np.any(np.logical_not(a)))
np.any(a) == np.logical_not(np.all(np.logical_not(a)))
not np.all(a) == np.any(np.logical_not(a))
not np.any(a) == np.all(np.logical_not(a))
答案 4 :(得分:1)
这会奏效。
def check(arr):
if np.all(arr == 0):
return True
return False
答案 5 :(得分:0)
如果 ur 数组中的所有元素都大于或等于 0。我认为使用 sum 是最快的方法。
test = np.ones((128, 128, 128))
%%timeit
not np.any(test)
>>> 1.46 ms ± 9.09 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%%timeit
np.sum(test) == 0
>>> 646 µs ± 3.19 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
答案 6 :(得分:-7)
如果您正在测试所有零以避免在另一个numpy函数上发出警告,那么将该行包装在try中,除了block之外将无需在您感兴趣的操作之前进行零测试即
try: # removes output noise for empty slice
mean = np.mean(array)
except:
mean = 0