如果嵌套数组的最大值超过阈值,则获取嵌套数组的Numpy条件

时间:2015-04-13 12:11:54

标签: python arrays numpy max

我有以下数组:

arr = numpy.array([[.5, .5], [.9, .1], [.8, .2]])

我想得到arr indices ,其中包含一个最大值大于或等于.9的数组。因此,对于这种情况,结果将是[1],因为索引为1 [.9, .1]的数组是唯一一个最大值为> = 9的数组。

我试过了:

>>> condition = np.max(arr) >= .9
>>> arr[condition]
array([ 0.5,  0.5])

但是,如你所见,它会产生错误的答案。

2 个答案:

答案 0 :(得分:2)

In [18]: arr = numpy.array([[.5, .5], [.9, .1], [.8, .2]])

In [19]: numpy.argwhere(numpy.max(arr, 1) >= 0.9)
Out[19]: array([[1]])

答案 1 :(得分:1)

您得到错误答案的原因是因为np.max(arr)为您提供了展平数组的最大值。您想要np.max(arr, axis=1),或者更好,arr.max(axis=1)

(arr.max(axis=1)>=.9).nonzero()