序列中的n个最大元素(需要保留重复)

时间:2011-07-12 19:06:27

标签: python algorithm sorting heap sequence

我需要在元组列表中找到n个最大的元素。以下是前3个元素的示例。

# I have a list of tuples of the form (category-1, category-2, value)
# For each category-1, ***values are already sorted descending by default***
# The list can potentially be approximately a million elements long.
lot = [('a', 'x1', 10), ('a', 'x2', 9), ('a', 'x3', 9), 
       ('a', 'x4',  8), ('a', 'x5', 8), ('a', 'x6', 7),
       ('b', 'x1', 10), ('b', 'x2', 9), ('b', 'x3', 8), 
       ('b', 'x4',  7), ('b', 'x5', 6), ('b', 'x6', 5)]

# This is what I need. 
# A list of tuple with top-3 largest values for each category-1
ans = [('a', 'x1', 10), ('a', 'x2', 9), ('a', 'x3', 9), 
       ('a', 'x4', 8), ('a', 'x5', 8),
       ('b', 'x1', 10), ('b', 'x2', 9), ('b', 'x3', 8)]

我尝试使用heapq.nlargest。但是它只返回前3个最大的元素,并且不返回重复项。例如,

heapq.nlargest(3, [10, 10, 10, 9, 8, 8, 7, 6])
# returns
[10, 10, 10]
# I need
[10, 10, 10, 9, 8, 8]

我只能想到蛮力的做法。这就是我所拥有的并且有效。

res, prev_t, count = [lot[0]], lot[0], 1
for t in lot[1:]:
    if t[0] == prev_t[0]:
        count = count + 1 if t[2] != prev_t[2] else count
        if count <= 3:
            res.append(t)   
    else:
        count = 1
        res.append(t)
    prev_t = t

print res

关于如何实现这一点的任何其他想法?谢谢!

编辑:timeit结果显示,一百万个元素的列表显示mhyfritz's solution在蛮力的1/3时间内运行。不想让问题太长。所以在my answer中添加了更多详细信息。

6 个答案:

答案 0 :(得分:7)

我从你的lot分组的代码片段中获取它w.r.t. 类别-1 。以下应该工作:

from itertools import groupby, islice
from operator import itemgetter

ans = []
for x, g1 in groupby(lot, itemgetter(0)):
    for y, g2 in islice(groupby(g1, itemgetter(2)), 0, 3):
        ans.extend(list(g2))

print ans
# [('a', 'x1', 10), ('a', 'x2', 9), ('a', 'x3', 9), ('a', 'x4', 8), ('a', 'x5', 8),
#  ('b', 'x1', 10), ('b', 'x2', 9), ('b', 'x3', 8)]

答案 1 :(得分:2)

如果您已经按照这种方式对输入数据进行了排序,那么很可能您的解决方案比基于heapq的解决方案要好一些。

您的算法复杂度为O(n),而基于heapq的算法复杂度为O(n * log(3)),并且可能需要对数据进行更多传递才能正确排列。

答案 2 :(得分:1)

其他一些细节......我计算了使用itertools的{​​{3}}和我的代码(暴力)。

以下是timeit的{​​{1}}结果以及包含100万个元素的列表。

n = 10

如果有人好奇,这里有他的代码如何运作的痕迹。

# Here's how I built the sample list of 1 million entries.
lot = []
for i in range(1001):
    for j in reversed(range(333)):
        for k in range(3):
            lot.append((i, 'x', j))

# timeit Results for n = 10
brute_force = 6.55s
itertools = 2.07s
# clearly the itertools solution provided by mhyfritz is much faster.

答案 3 :(得分:0)

这就是这个想法,用你想要排序的值做一个dict作为键,以及将该值作为值的元组列表。

然后按键对dict的项目进行排序,从顶部获取项目,提取它们的值并加入它们。

快速,丑陋的代码:

>>> sum(
        map(lambda x: x[1],
            sorted(dict([(x[2], filter(lambda y: y[2] == x[2], lot))
                for x in lot]).items(),
                reverse=True)[:3]),
    [])

7: [('a', 'x1', 10),
 ('b', 'x1', 10),
 ('a', 'x2', 9),
 ('a', 'x3', 9),
 ('b', 'x2', 9),
 ('a', 'x4', 8),
 ('a', 'x5', 8),
 ('b', 'x3', 8)]

只是为了给你一些想法,希望它有所帮助。如果您需要澄清,请在评论中提出

答案 4 :(得分:0)

这是怎么回事?它不会完全返回您想要的结果,因为它会在y上进行反向排序。

# split lot by first element of values
lots = defaultdict(list)
for x, y, z in lot:
    lots[x].append((y, z))

ans = []
for x, l in lots.iteritems():
    # find top-3 unique values
    top = nlargest(3, set(z for (y, z) in l))
    ans += [(x, y, z) for (z, y) in sorted([(z, y) for (y, z) in l
                                                   if z in top],
                                           reverse=True)]

print ans

答案 5 :(得分:0)

from collections import *

categories = defaultdict(lambda: defaultdict(lambda: set()))
for t in myTuples:
    cat1,cat2,val = t
    categories[cat1][val].add(t)

def onlyTopThreeKeys(d):
    keys = sorted(d.keys())[-3:]
    return {k:d[k] for k in keys}

print( {cat1:onlyTopThreeKeys(sets) for cat1,sets in categories.items()} )

结果:

{'a': {8: {('a', 'x5', 8), ('a', 'x4', 8)},
       9: {('a', 'x3', 9), ('a', 'x2', 9)},
       10: {('a', 'x1', 10)}},
 'b': {8: {('b', 'x3', 8)}, 
       9: {('b', 'x2', 9)}, 
       10: {('b', 'x1', 10)}}}

平面列表:我做了上面的方法,因为它为您提供了更多信息。要获得一个平面列表,请使用闭包以onlyTopThreeKeys

发出结果
from collections import *

def topTiedThreeInEachCategory(tuples):
    categories = defaultdict(lambda: defaultdict(lambda: set()))
    for t in myTuples:
        cat1,cat2,val = t
        categories[cat1][val].add(t)

    reap = set()

    def sowTopThreeKeys(d):
        keys = sorted(d.keys())[-3:]
        for k in keys:
            for x in d[k]:
                reap.add(x)
    for sets in categories.values():
        sowTopThreeKeys(sets)

    return reap

结果:

>>> topTiedThreeInEachCategory(myTuples)
{('b', 'x2', 9), ('a', 'x1', 10), ('b', 'x3', 8), ('a', 'x2', 9), ('a', 'x4', 8), ('a', 'x3', 9), ('a', 'x5', 8), ('b', 'x1', 10)}

如果您的输入保证按照示例输入进行排序,也可以使用itertools.groupby,但如果排序发生变化,这将导致代码中断。