获取2d numpy数组的argmin索引

时间:2015-09-10 03:22:31

标签: python arrays numpy

我有一个2D numpy距离数组:

a = np.array([[2.0, 12.1, 99.2], 
              [1.0, 1.1, 1.2], 
              [1.04, 1.05, 1.5], 
              [4.1, 4.2, 0.2], 
              [10.0, 11.0, 12.0], 
              [3.9, 4.9, 4.99] 
             ])

我需要一个评估每一行的函数,并返回具有最小值的列的列索引。当然,这可以通过以下方式轻松完成:

np.argmin(a, axis=1) 

产生:

[0, 0, 0, 2, 0, 0]

但是,我有一些限制因素:

  1. argmin评估应仅考虑低于5.0的距离。如果一行中的距离都不低于5.0,则返回'-1'作为索引
  2. 为所有行返回的索引列表必须是唯一的(即,如果两行或更多行以相同的列索引结束,则与给定列索引的距离较小的行将被赋予优先级,而所有其他行必须返回不同的列索引)。我猜这会使问题成为迭代问题,因为如果其中一行被碰撞,那么它随后可能会与具有相同列索引的另一行发生冲突。
  3. 任何未分配的行都应返回“-1”
  4. 因此,最终输出应如下所示:

    [-1, 0, 1, 2, -1, -1]
    

    一个起点是:

    1. 执行argsort
    2. 为行
    3. 分配唯一列索引
    4. 从每行中删除指定的列索引
    5. 解决tie-breaks
    6. 重复步骤2-4,直到指定所有列索引
    7. 在Python中有没有直接的方法来实现这一目标?

1 个答案:

答案 0 :(得分:0)

这会循环遍历列数,我假设这些列小于行数:

def find_smallest(a):
    i = np.argmin(a, 1)
    amin = a[np.arange(len(a)), i] # faster than a.min(1)?
    toobig = amin >=5
    i[toobig] = -1
    for u, c in zip(*np.unique(i, return_counts=True)):
        #u, c are the unique values and number of occurrences in `i`
        if c < 2:
            # no repeats of this index
            continue
        mask = i==u # the values in i that match u, which has repeats
        notclosest = np.where(mask)[0].tolist() # indices of the repeats
        notclosest.pop(np.argmin(amin[mask])) # the smallest a value is not a 'repeat', remove it from the list
        i[notclosest] = -1 # and mark all the repeats as -1
    return i

注意,我使用了-1而不是np.nan,因为索引数组是int。布尔索引的任何减少都会有所帮助。我想使用np.unique(i)中的一个可选附加输出但不能。