我有以下数组:
arr = numpy.array([[.5, .5], [.9, .1], [.8, .2]])
我想得到arr
的 indices ,其中包含一个最大值大于或等于.9的数组。因此,对于这种情况,结果将是[1]
,因为索引为1 [.9, .1]
的数组是唯一一个最大值为> = 9的数组。
我试过了:
>>> condition = np.max(arr) >= .9
>>> arr[condition]
array([ 0.5, 0.5])
但是,如你所见,它会产生错误的答案。
答案 0 :(得分:2)
In [18]: arr = numpy.array([[.5, .5], [.9, .1], [.8, .2]])
In [19]: numpy.argwhere(numpy.max(arr, 1) >= 0.9)
Out[19]: array([[1]])
答案 1 :(得分:1)
您得到错误答案的原因是因为np.max(arr)
为您提供了展平数组的最大值。您想要np.max(arr, axis=1)
,或者更好,arr.max(axis=1)
。
(arr.max(axis=1)>=.9).nonzero()