我有一个很大的numpy
矩阵M
。矩阵的某些行的所有元素都为零,我需要获取这些行的索引。我考虑的天真方法是循环遍历矩阵中的每一行,然后检查每个元素。但是我认为使用numpy
可以更好,更快地完成此任务。我希望你能帮忙!
答案 0 :(得分:32)
这是一种方式。我假设已使用import numpy as np
导入numpy。
In [20]: a
Out[20]:
array([[0, 1, 0],
[1, 0, 1],
[0, 0, 0],
[1, 1, 0],
[0, 0, 0]])
In [21]: np.where(~a.any(axis=1))[0]
Out[21]: array([2, 4])
这个答案略有不同:How to check that a matrix contains a zero column?
这是发生了什么:
如果数组中的任何值为“truthy”,则any
方法返回True。非零数字被视为True,0被视为False。通过使用参数axis=1
,该方法将应用于每一行。对于示例a
,我们有:
In [32]: a.any(axis=1)
Out[32]: array([ True, True, False, True, False], dtype=bool)
因此每个值指示相应的行是否包含非零值。 ~
运算符是二进制“not”或补码:
In [33]: ~a.any(axis=1)
Out[33]: array([False, False, True, False, True], dtype=bool)
(给出相同结果的替代表达式为(a == 0).all(axis=1)
。)
要获取行索引,我们使用where
函数。它返回其参数为True的索引:
In [34]: np.where(~a.any(axis=1))
Out[34]: (array([2, 4]),)
请注意where
返回包含单个数组的元组。 where
适用于n维数组,因此它总是返回一个元组。我们想要那个元组中的单个数组。
In [35]: np.where(~a.any(axis=1))[0]
Out[35]: array([2, 4])
答案 1 :(得分:1)
如果元素为int(0)
,则可接受的答案有效。如果要查找所有值均为0.0(浮点数)的行,则必须使用np.isclose()
:
print(x)
# output
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
1., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0.],
])
np.where(np.all(np.isclose(labels, 0), axis=1))
(array([ 0, 3]),)
注意:这也适用于PyTorch张量,当您要查找归零的multihot编码矢量时非常有用。
答案 2 :(得分:0)
使用np.sum
的解决方案,
如果您想使用阈值,则很有用
a = np.array([[1.0, 1.0, 2.99],
[0.0000054, 0.00000078, 0.00000232],
[0, 0, 0],
[1, 1, 0.0],
[0.0, 0.0, 0.0]])
print(np.where(np.sum(np.abs(a), axis=1)==0)[0])
>>[2 4]
print(np.where(np.sum(np.abs(a), axis=1)<0.0001)[0])
>>[1 2 4]
使用np.prod
检查行是否至少包含一个零元素
print(np.where(np.prod(a, axis=1)==0)[0])
>>[2 3 4]