如何在满足多个条件的numpy数组中找到索引?

时间:2014-10-03 00:04:45

标签: python numpy where

我在Python中有一个数组,如下所示:

示例:

>>> scores = numpy.asarray([[8,5,6,2], [9,4,1,4], [2,5,3,8]])
>>> scores
array([[8, 5, 6, 2],
   [9, 4, 1, 4],
   [2, 5, 3, 8]])

我想在[row, col]中找到值为的所有scores索引:

1)行中的最小值

2)大于阈值

3)最多是该行中下一个最大值的.8倍

我想尽可能高效地完成它,最好没有任何循环。我一直在努力解决这个问题,所以你能提供的任何帮助都会非常感激!

2 个答案:

答案 0 :(得分:2)

应该按照

的方式行事
In [1]: scores = np.array([[8,5,6,2], [9,4,1,4], [2,5,3,8]]); threshold = 1.1; scores
Out[1]: 
array([[8, 5, 6, 2],
       [9, 4, 1, 4],
       [2, 5, 3, 8]])

In [2]: part = np.partition(scores, 2, axis=1); part
Out[2]: 
array([[2, 5, 6, 8],
       [1, 4, 4, 9],
       [2, 3, 5, 8]])

In [3]: row_mask = (part[:,0] > threshold) & (part[:,0] <= 0.8 * part[:,1]); row_mask
Out[3]: array([ True, False,  True], dtype=bool)

In [4]: rows = row_mask.nonzero()[0]; rows
Out[4]: array([0, 2])

In [5]: cols = np.argmin(scores[row_mask], axis=1); cols
Out[5]: array([3, 0])

如果你正在寻找实际的坐标对,那么你可以zip

In [6]: coords = zip(rows, cols); coords
Out[6]: [(0, 3), (2, 0)]

或者,如果您计划查看这些元素,可以按原样使用它们:

In [7]: scores[rows, cols]
Out[7]: array([2, 2])

答案 1 :(得分:1)

我认为你很难用任何for循环(或至少执行这样的循环,但可能伪装它作为其他东西)这样做,看到作为操作如何仅依赖于行,并且您想要为每一行执行操作。它不是最有效的(并且可能取决于条件2和3的频率是多少),但这将起作用:

import heapq
threshold = 1.5
ratio = .8
scores = numpy.asarray([[8,5,6,2], [9,4,1,4], [2,5,3,8]])

found_points = []
for i,row in enumerate(scores):
    lowest,second_lowest = heapq.nsmallest(2,row)
    if lowest > threshold and lowest <= ratio*second_lowest:
        found_points.append([i,numpy.where(row == lowest)[0][0]])

你得到(例如):

found_points = [[0, 3], [2, 0]]