使用numpy.argpartition

时间:2016-02-04 16:16:40

标签: python sorting numpy

我正在尝试获取数组的前N个大值。

例如:

total_A = [0. 0. 30. 0. 20. 58. 0. 0. 31. 0. 0. 0. 398. 132. 0. 0. 316. 0.]

使用此:

top_A = numpy.argpartition(total_A, -18, axis=None)[-18:]

我明白了:

[(top_A[i], numpy.round(total_A[top_A[i]])) for i in top_A]

[(0, 0.0),
 (1, 0.0),
 (2, 30.0),
 (3, 0.0),
 (4, 20.0),
 (5, 58.0),
 (6, 0.0),
 (7, 0.0),
 (8, 31.0),
 (9, 0.0),
 (10, 0.0),
 (11, 0.0),
 (12, 398.0),
 (13, 132.0),
 (14, 0.0),
 (15, 0.0),
 (16, 316.0),
 (17, 0.0)]

这显然只是我的输入数组顺序。 另外,如果我试试这个:

top_A = numpy.argpartition(total_A, -10, axis=None)[-10:]

我收到:

IndexError                                Traceback (most recent call last)
<ipython-input-34-c457841cdbf1> in <module>()
     16 
     17 print 'Top categories for input matrix A:'
---> 18 [(top_A[i], numpy.round(total_A[top_A[i]])) for i in top_A]

我想知道我错过了什么。

2 个答案:

答案 0 :(得分:1)

这一行:

[(top_A[i], numpy.round(total_A[top_A[i]])) for i in top_A]

应该简化为

[(i, numpy.round(total_A[i])) for i in top_A]

top_A中的值是total_A的索引。

例如,

In [318]: a = np.array([1, 3, 5, 7, 9, 2, 4, 6, 8])

获取3个最大值的索引:

In [319]: p = np.argpartition(a, -3)[-3:]

In [320]: p
Out[320]: array([3, 8, 4])

显示3个最大值:

In [321]: a[p]
Out[321]: array([7, 8, 9])

显示包含索引和相应值的元组:

In [322]: [(i, a[i]) for i in p]
Out[322]: [(3, 7), (8, 8), (4, 9)]

答案 1 :(得分:0)

如果我明白你的意思,你需要尝试这样的事情:

A[numpy.argpartition(A, A.size - N)][-N:]

获得数组A的N值。