在Python中获取列表中较小的n个元素

时间:2008-12-08 19:10:10

标签: python algorithm sorting

我需要在Python中获得较少的n个列表。我需要这个非常快,因为它是性能的关键部分,需要重复很多次。

n通常不大于10,列表通常有大约20000个元素。每次调用该函数时,列表总是不同的。无法进行排序。

最初,我写了这个函数:

def mins(items, n):
    mins = [float('inf')]*n
    for item in items:
        for i, min in enumerate(mins):
            if item < min:
                mins.insert(i, item)
                mins.pop()
                break
    return mins

但是这个函数无法击败对整个列表进行排序的简单排序(项目)[:n]。这是我的测试:

from random import randint, random
import time

test_data = [randint(10, 50) + random() for i in range(20000)]

init = time.time()
mins = mins(test_data, 8)
print 'mins(items, n):', time.time() - init

init = time.time()
mins = sorted(test_data)[:8]
print 'sorted(items)[:n]:', time.time() - init

结果:

mins(items, n): 0.0632939338684
sorted(items)[:n]: 0.0231449604034

sorted()[:n]快三倍。我相信这是因为:

  1. insert()操作成本很高,因为Python列表不是链表。
  2. sorted()是一个优化的c函数,我的是纯python。
  3. 有没有办法打败sorted()[:n]? 我应该使用C扩展,Pyrex或Psyco还是类似的东西?

    提前感谢您的回答。

6 个答案:

答案 0 :(得分:15)

你实际上想要一个已分类的分钟序列。

mins = items[:n]
mins.sort()
for i in items[n:]:
    if i < mins[-1]: 
        mins.append(i)
        mins.sort()
        mins= mins[:n]

这样可以更快地运行很多因为你甚至没有看到分钟,除非它的值可以大于给定的项目。大约是原算法时间的十分之一。

我的戴尔零时间运行。我必须运行10次以获得可测量的运行时间。

mins(items, n): 0.297000169754
sorted(items)[:n]: 0.109999895096
mins2(items)[:n]: 0.0309998989105

使用bisect.insort而不是追加和排序可以进一步加快这一点。

答案 1 :(得分:12)

import heapq

nlesser_items = heapq.nsmallest(n, items)

这是S.Lott's algorithm的正确版本:

from bisect    import insort
from itertools import islice

def nsmallest_slott_bisect(n, iterable, insort=insort):
    it   = iter(iterable)
    mins = sorted(islice(it, n))
    for el in it:
        if el <= mins[-1]: #NOTE: equal sign is to preserve duplicates
            insort(mins, el)
            mins.pop()

    return mins

性能:

$ python -mtimeit -s "import marshal; from nsmallest import nsmallest$label as nsmallest; items = marshal.load(open('items.marshal','rb')); n = 10"\
 "nsmallest(n, items)"
nsmallest_heapq
100 loops, best of 3: 12.9 msec per loop
nsmallest_slott_list
100 loops, best of 3: 4.37 msec per loop
nsmallest_slott_bisect
100 loops, best of 3: 3.95 msec per loop

nsmallest_slott_bisectheapq的{​​{1}} 快3倍(对于n = 10,len(项目)= 20000)。 nsmallest只是稍慢一点。目前还不清楚为什么heapq的最小是如此缓慢;它的算法几乎与上面给出的相同(对于小n)。

答案 2 :(得分:3)

我喜欢埃里克森的堆积思想。我也不知道Python,但这里似乎有一个固定解决方案:heapq — Heap queue algorithm

答案 3 :(得分:2)

可能是使用bisect模块:

import bisect

def mins(items, n):
    mins = [float('inf')]*n
    for item in items:
        bisect.insort(mins, item)
        mins.pop()
    return mins

然而,这对我来说只是快一点:

mins(items, n): 0.0892250537872
sorted(items)[:n]: 0.0990262031555

使用psyco可以加快速度:

import bisect
import psyco
psyco.full()

def mins(items, n):
    mins = [float('inf')]*n
    for item in items:
        bisect.insort(mins, item)
        mins.pop()
    return mins

结果:

mins(items, n): 0.0431621074677
sorted(items)[:n]: 0.0859830379486

答案 4 :(得分:2)

如果最关心速度,那么最快的方法将是c。 Psyco有一个前期成本,但可能会很快。 我推荐Cython for python - &gt; c编译(pf Pyrex更新)。

在c中手动编码将是最好的,并允许您使用特定于您的问题域的数据结构。

但请注意:

  

“在C中编译错误的算法   可能不比右边快   Python中的算法“@ S.Lott

我想添加S.Lott的评论,以便引起注意。 Python是一种优秀的原型语言,您可以在其中制定出一种算法,以便以后将其翻译成较低级别的语言。

答案 5 :(得分:0)

为什么不在O(N)时间内调用select_n_th元素,然后用n_th元素将数组分成两部分,这应该是最快的。

PS: 如果您没有指定n个最小元素的顺序,则此O(N)算法有效 下面的链接似乎做了选择算法。 http://code.activestate.com/recipes/269554-select-the-nth-smallest-element/

假设数组没有重复的元素,代码对我有效。效率仍然取决于问题规模,如果n <10,可能O(logn * N)算法就足够了。

import random
import numpy as np
def select(data, n):
    "Find the nth rank ordered element (the least value has rank 0)."
    data = list(data)
    if not 0 <= n < len(data):
        raise ValueError('not enough elements for the given rank')
    while True:
        pivot = random.choice(data)
        pcount = 0
        under, over = [], []
        uappend, oappend = under.append, over.append
        for elem in data:
            if elem < pivot:
                uappend(elem)
            elif elem > pivot:
                oappend(elem)
            else:
                pcount += 1
        if n < len(under):
            data = under
        elif n < len(under) + pcount:
            return pivot
        else:
            data = over
            n -= len(under) + pcount


def n_lesser(data,n):
    data_nth = select(data,n)
    ind = np.where(data<data_nth)
    return data[ind]