如果非零元素的计数超过阈值,则返回每行中的非零元素索引

时间:2019-06-25 12:54:59

标签: python numpy

我有一个真值数组,形式为:

truth_arr = array([[ True, False, False,  True],
                   [False,  True, False, False],
                   [False, False,  True,  True],
                   [False, False, False,  True]])

,我想获取连续有多个真项的所有真项的索引。上面的数组应该返回如下内容:

 [(0, 0), (0, 3), (2, 2), (2, 3)]

(不一定是这种形式)。

2 个答案:

答案 0 :(得分:3)

您可以屏蔽不符合条件的行,然后使用np.nonzero

np.nonzero(truth_arr * truth_arr.sum(axis=1, keepdims=True)>1)
# (array([0, 0, 2, 2]), array([0, 3, 2, 3]))

如果您确实想要索引的元组格式列表,请在之后使用np.column_stack

np.column_stack(
    np.nonzero(truth_arr * truth_arr.sum(axis=1, keepdims=True)>1))
# array([[0, 0],
#        [0, 3],
#        [2, 2],
#        [2, 3]])

或者,更Python化地

[*zip(*np.nonzero(truth_arr * truth_arr.sum(axis=1, keepdims=True)>1))]
# [(0, 0), (0, 3), (2, 2), (2, 3)]

答案 1 :(得分:1)

@cs95's answer的微小变化,只是为了以问题中建议的形式获得输出:

import numpy as np

truth_arr = np.array([[True, False, False, True],
                      [False, True, False, False],
                      [False, False, True, True],
                      [False, False, False, True]])

indices = np.nonzero(truth_arr * truth_arr.sum(axis=1, keepdims=True) > 1)
result = list(zip(*indices))
print(result)

输出:

  

[(0,0),(0,3),(2,2),(2,3)]

注意:

原始答案的输出是numpy喜欢这些索引的方式,因此您可以使用: truth_arr[indices]获得[ True True True True]

truth_arr[result]会导致错误...