在numpy数组中找到第n个最小元素

时间:2014-03-20 22:07:49

标签: python numpy

我需要找到1D numpy.array中最小的第n个元素。

例如:

a = np.array([90,10,30,40,80,70,20,50,60,0])

我想获得第5个最小元素,所以我想要的输出是40

我目前的解决方案是:

result = np.max(np.partition(a, 5)[:5])

然而,找到5个最小的元素然后拿出最大的元素对我来说似乎不太笨拙。有没有更好的方法呢?我错过了一个能实现目标的功能吗?

有一些问题与此类似的标题,但我没有看到任何回答我的问题。

修改

我原本应该提到它,但性能对我来说非常重要;因此,heapq解决方案虽然不错但对我不起作用。

import numpy as np
import heapq

def find_nth_smallest_old_way(a, n):
    return np.max(np.partition(a, n)[:n])

# Solution suggested by Jaime and HYRY    
def find_nth_smallest_proper_way(a, n):
    return np.partition(a, n-1)[n-1]

def find_nth_smallest_heapq(a, n):
    return heapq.nsmallest(n, a)[-1]
#    
n_iterations = 10000

a = np.arange(1000)
np.random.shuffle(a)

t1 = timeit('find_nth_smallest_old_way(a, 100)', 'from __main__ import find_nth_smallest_old_way, a', number = n_iterations)
print 'time taken using partition old_way: {}'.format(t1)    
t2 = timeit('find_nth_smallest_proper_way(a, 100)', 'from __main__ import find_nth_smallest_proper_way, a', number = n_iterations)
print 'time taken using partition proper way: {}'.format(t2) 
t3 = timeit('find_nth_smallest_heapq(a, 100)', 'from __main__ import find_nth_smallest_heapq, a', number = n_iterations)  
print 'time taken using heapq : {}'.format(t3)

结果:

time taken using partition old_way: 0.255564928055
time taken using partition proper way: 0.129678010941
time taken using heapq : 7.81094002724

3 个答案:

答案 0 :(得分:22)

除非我遗漏了什么,否则你想要做的是:

>>> a = np.array([90,10,30,40,80,70,20,50,60,0])
>>> np.partition(a, 4)[4]
40

np.partition(a, k)会将k的{​​{1}}个最小元素放在a,将较小的值放在a[k]中,将较大的值放在a[:k]中。唯一需要注意的是,由于0索引,第五个元素位于索引4。

答案 1 :(得分:4)

您可以使用heapq.nsmallest

>>> import numpy as np
>>> import heapq
>>> 
>>> a = np.array([90,10,30,40,80,70,20,50,60,0])
>>> heapq.nsmallest(5, a)[-1]
40

答案 2 :(得分:0)

您不需要致电numpy.max()

def nsmall(a, n):
    return np.partition(a, n)[n]