我想在数组中形成一个包含k个最小值索引的数组:
import heapq
import numpy as np
a= np.array([[1, 3, 5, 2, 3],
[7, 6, 5, 2, 4],
[2, 0, 5, 6, 4]])
[t[0] for t in heapq.nsmallest(2,enumerate(a[1]),lambda(t):t[1])]
===[3, 4]
但这失败了:
[t[0] for t in heapq.nsmallest(2,enumerate(a.all()),lambda(t):t[1])]
Traceback (most recent call last):
File "<pyshell#19>", line 1, in <module>
[t[0] for t in heapq.nsmallest(2,enumerate(a.all()),lambda(t):t[1])]
TypeError: 'numpy.bool_' object is not iterable
答案 0 :(得分:2)
您的问题出在a.all()
:
[t[0] for t in heapq.nsmallest(2,enumerate(a.all()),lambda(t):t[1])]
这会检查数组中所有元素的真实性,即False
(因为你有0)。
如果数组与k相比不是太大,您可以使用.argsort
获取值。在这里,我将为每一行选择两个最大的位置:
print a.argsort()[:,:2]
array([[0, 3],
[3, 4],
[1, 0]])
如果你想要全局最小值的位置,请将阵列展平:
a.flatten().argsort()[:2]
如果数组非常大,您可以使用np.argpartition
获得更好的性能,这只会执行部分排序。
答案 1 :(得分:1)
您可以将numpy.ndenumerate
与堆一起使用,也可以使用David建议的部分排序:
a = np.array([[1, 3, 5, 2, 3],
[7, 6, 5, 2, 4],
[2, 0, 5, 6, 4]])
heap = [(v, k) for k,v in numpy.ndenumerate(npa)]
heapq.heapify(heap)
heapq.nsmallest(10, heap) # for k = 10
你得到:
[(0, (2, 1)),
(1, (0, 0)),
(2, (0, 3)),
(2, (1, 3)),
(2, (2, 0)),
(3, (0, 1)),
(3, (0, 4)),
(4, (1, 4)),
(4, (2, 4)),
(5, (0, 2))]