我有一个真值数组,形式为:
truth_arr = array([[ True, False, False, True],
[False, True, False, False],
[False, False, True, True],
[False, False, False, True]])
,我想获取连续有多个真项的所有真项的索引。上面的数组应该返回如下内容:
[(0, 0), (0, 3), (2, 2), (2, 3)]
(不一定是这种形式)。
答案 0 :(得分:3)
您可以屏蔽不符合条件的行,然后使用np.nonzero
:
np.nonzero(truth_arr * truth_arr.sum(axis=1, keepdims=True)>1)
# (array([0, 0, 2, 2]), array([0, 3, 2, 3]))
如果您确实想要索引的元组格式列表,请在之后使用np.column_stack
:
np.column_stack(
np.nonzero(truth_arr * truth_arr.sum(axis=1, keepdims=True)>1))
# array([[0, 0],
# [0, 3],
# [2, 2],
# [2, 3]])
或者,更Python化地
[*zip(*np.nonzero(truth_arr * truth_arr.sum(axis=1, keepdims=True)>1))]
# [(0, 0), (0, 3), (2, 2), (2, 3)]
答案 1 :(得分:1)
@cs95's answer的微小变化,只是为了以问题中建议的形式获得输出:
import numpy as np
truth_arr = np.array([[True, False, False, True],
[False, True, False, False],
[False, False, True, True],
[False, False, False, True]])
indices = np.nonzero(truth_arr * truth_arr.sum(axis=1, keepdims=True) > 1)
result = list(zip(*indices))
print(result)
输出:
[(0,0),(0,3),(2,2),(2,3)]
注意:
原始答案的输出是numpy
喜欢这些索引的方式,因此您可以使用:
truth_arr[indices]
获得[ True True True True]
,
truth_arr[result]
会导致错误...