如何在多维数组中找到k个最小数的索引?

时间:2014-06-21 19:11:35

标签: python numpy

我想在数组中形成一个包含k个最小值索引的数组:

import heapq
import numpy as np

a= np.array([[1, 3, 5, 2, 3],
       [7, 6, 5, 2, 4],
       [2, 0, 5, 6, 4]])

[t[0] for t in heapq.nsmallest(2,enumerate(a[1]),lambda(t):t[1])]
===[3, 4]

但这失败了:

[t[0] for t in heapq.nsmallest(2,enumerate(a.all()),lambda(t):t[1])]

Traceback (most recent call last):
  File "<pyshell#19>", line 1, in <module>
    [t[0] for t in heapq.nsmallest(2,enumerate(a.all()),lambda(t):t[1])]
TypeError: 'numpy.bool_' object is not iterable

2 个答案:

答案 0 :(得分:2)

您的问题出在a.all()

[t[0] for t in heapq.nsmallest(2,enumerate(a.all()),lambda(t):t[1])]

这会检查数组中所有元素的真实性,即False(因为你有0)。

如果数组与k相比不是太大,您可以使用.argsort获取值。在这里,我将为每一行选择两个最大的位置:

print a.argsort()[:,:2]

array([[0, 3],
       [3, 4],
       [1, 0]])

如果你想要全局最小值的位置,请将阵列展平:

a.flatten().argsort()[:2]

如果数组非常大,您可以使用np.argpartition获得更好的性能,这只会执行部分排序。

答案 1 :(得分:1)

您可以将numpy.ndenumerate与堆一起使用,也可以使用David建议的部分排序:

a = np.array([[1, 3, 5, 2, 3],
       [7, 6, 5, 2, 4],
       [2, 0, 5, 6, 4]])
heap = [(v, k) for k,v in numpy.ndenumerate(npa)]
heapq.heapify(heap)
heapq.nsmallest(10, heap) # for k = 10

你得到:

[(0, (2, 1)),
 (1, (0, 0)),
 (2, (0, 3)),
 (2, (1, 3)),
 (2, (2, 0)),
 (3, (0, 1)),
 (3, (0, 4)),
 (4, (1, 4)),
 (4, (2, 4)),
 (5, (0, 2))]