Optimizing nested loops in Python 2.6

Time: 2012-01-02 19:38:04

Tags: python optimization

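I have a function that takes a dictionary and a value n as input. Each item in the dictionary is a set with one or more values. The function should sort the dictionary keys, then extract and return n values. This function will be executed very often, so I am trying to optimize it. Any suggestions?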

def select_items(temp_dict, n):
  """Select n items from the dictionary"""
  res = []
  sort_keys = sorted(temp_dict.keys())
  count = 0

  for key in sort_keys:
    for pair in temp_dict[key]:
      if count < n:
        res.append(pair)
        count += 1
      else:
        return res

  return res

In this code I use a counter and an if statement to control how many values are selected. Is there a way to optimize it with functions from itertools or something else?

4 answers:

Answer 0: (score: 6)

Here is my first attempt (see select_items_faster), which cuts the running time by roughly 30% in the timings below:

In [12]: print _11
import itertools

def select_items_original(temp_dict, n):
  """Select n items from the dictionary"""
  res = []
  sort_keys = sorted(temp_dict.keys())
  count = 0

  for key in sort_keys:
    for pair in temp_dict[key]:
      if count < n:
        res.append(pair)
        count += 1
      else:
        return res

  return res

def select_items_faster(temp_dict, n):
    """Select n items from the dictionary"""
    items = temp_dict.items()
    items.sort()

    return list(itertools.chain.from_iterable(val for (_, val) in itertools.islice(items, n)))

test_dict = dict((x, ["a"] * int(x / 500)) for x in range(1000))
test_n = 300


In [13]: %timeit select_items_original(test_dict, test_n)
1000 loops, best of 3: 293 us per loop


In [14]: %timeit select_items_faster(test_dict, test_n)
1000 loops, best of 3: 203 us per loop

Replacing itertools.islice with [:n] doesn't help much:

def select_items_faster_slice(temp_dict, n):
    """Select n items from the dictionary"""
    items = temp_dict.items()
    items.sort()

    return list(itertools.chain.from_iterable(val for (_, val) in items[:n]))

In [16]: %timeit select_items_faster_slice(test_dict, test_n)
1000 loops, best of 3: 210 us per loop

Neither does using sorted:

In [18]: %timeit select_items_faster_sorted(test_dict, test_n)
1000 loops, best of 3: 213 us per loop


In [19]: print _17
def select_items_faster_sorted(temp_dict, n):
    """Select n items from the dictionary"""
    return list(itertools.chain.from_iterable(val for (_, val) in itertools.islice(sorted(temp_dict.items()), n)))

But a combination of map and __getitem__ is much faster:

In [22]: %timeit select_items_faster_map_getitem(test_dict, test_n)
10000 loops, best of 3: 90.7 us per loop

In [23]: print _20
def select_items_faster_map_getitem(temp_dict, n):
    """Select n items from the dictionary"""
    keys = temp_dict.keys()
    keys.sort()
    return list(itertools.chain.from_iterable(map(temp_dict.__getitem__, keys[:n])))

Replacing list(itertools.chain.from_iterable) with a bit of magic speeds things up somewhat:

In [28]: %timeit select_items_faster_map_getitem_list_extend(test_dict, test_n)
10000 loops, best of 3: 74.9 us per loop

In [29]: print _27
def select_items_faster_map_getitem_list_extend(temp_dict, n):
    """Select n items from the dictionary"""
    keys = temp_dict.keys()
    keys.sort()
    result = []
    filter(result.extend, map(temp_dict.__getitem__, keys[:n]))
    return result

Replacing map and the slice with their itertools equivalents squeezes out a little more speed:

In [31]: %timeit select_items_faster_map_getitem_list_extend_iterables(test_dict, test_n)
10000 loops, best of 3: 72.8 us per loop

In [32]: print _30
def select_items_faster_map_getitem_list_extend_iterables(temp_dict, n):
    """Select n items from the dictionary"""
    keys = temp_dict.keys()
    keys.sort()
    result = []
    filter(result.extend, itertools.imap(temp_dict.__getitem__, itertools.islice(keys, n)))
    return result

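This is about as fast as I think it can get: Python function calls are fairly slow in CPython, and this version minimizes the number of Python function calls made in the inner loop.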
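The effect of call overhead can be illustrated with a small sketch. The function names per_item and per_list and the test data are made up for illustration, and the code assumes Python 3 (the answer itself was written for Python 2.6). Absolute timings depend on the machine, so none are shown.

```python
from timeit import timeit

# Hypothetical test data: 300 lists of 10 items each.
lists = [list(range(10)) for _ in range(300)]

def per_item():
    res = []
    for lst in lists:
        for item in lst:
            res.append(item)   # one call per item, issued from the inner loop
    return res

def per_list():
    res = []
    for lst in lists:
        res.extend(lst)        # one call per list; the copying happens in C
    return res

# Both variants build the same list; only the number of calls differs.
assert per_item() == per_list()
print("per_item: %.4f s" % timeit(per_item, number=1000))
print("per_list: %.4f s" % timeit(per_list, number=1000))
```

On CPython the per_list variant is usually the faster of the two, which is the same reason the extend-based versions above win.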

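Notes:

  • Since the OP gave no hint about what the input data looks like, I had to guess. I could be way off, and that could drastically change what "fast" means.
  • None of my implementations returns exactly n items: each one concatenates the values of the first n keys, so the length of the result depends on the data.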
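To see how the fast variants above differ from the original function in what they return, here is a small sketch. The names first_n_keys and first_n_values and the sample dictionary are made up for illustration, and the code assumes Python 3 rather than the Python 2.6 used in this answer.

```python
from itertools import chain, islice

def first_n_keys(temp_dict, n):
    """Concatenate the values of the first n keys (what the fast variants do)."""
    keys = sorted(temp_dict)
    return list(chain.from_iterable(temp_dict[k] for k in keys[:n]))

def first_n_values(temp_dict, n):
    """Return the first n values in key order (what the question asks for)."""
    keys = sorted(temp_dict)
    return list(islice(chain.from_iterable(temp_dict[k] for k in keys), n))

# Hypothetical sample data, not taken from the benchmark above.
sample = {1: [1, 2, 3], 2: [4, 5, 6], 3: [7]}
print(first_n_keys(sample, 2))    # [1, 2, 3, 4, 5, 6]
print(first_n_values(sample, 2))  # [1, 2]
```

The two only agree when every key holds exactly one value, so the timings are comparable only for data of that shape.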

Edit: profiling J.F. Sebastian's code with the same method:

In [2]: %timeit select_items_heapq(test_dict, test_n)
1000 loops, best of 3: 572 us per loop

In [3]: print _1
from itertools import *
import heapq

def select_items_heapq(temp_dict, n):
    return list(islice(chain.from_iterable(imap(temp_dict.get, heapq.nsmallest(n, temp_dict))),n))

And TokenMacGuy's code:

In [5]: %timeit select_items_tokenmacguy_first(test_dict, test_n)
1000 loops, best of 3: 201 us per loop

In [6]: %timeit select_items_tokenmacguy_second(test_dict, test_n)
1000 loops, best of 3: 730 us per loop

In [7]: print _4
def select_items_tokenmacguy_first(m, n):
    k, v, r = m.keys(), m.values(), range(len(m))
    r.sort(key=k.__getitem__)
    return [v[i] for i in r[:n]]

import heapq
def select_items_tokenmacguy_second(m, n):
    k, v, r = m.keys(), m.values(), range(len(m))
    smallest = heapq.nsmallest(n, r, k.__getitem__)
    for i, ind in enumerate(smallest):
        smallest[i] = v[ind]
    return smallest

Answer 1: (score: 2)

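In my opinion, using a generator expression and returning the generator is a cleaner, more readable alternative. Slicing the sorted list of keys avoids the if clause.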

def select_items(dic, n):
  return (dic[key] for key in sorted(dic.keys())[:n])

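As for speed: I think the sort call itself is probably the biggest bottleneck here, although you probably shouldn't worry about it until the dictionary gets large. In that case you should consider keeping the dictionary sorted in the first place. You pay a price in complexity at insertion time, but lookups and selections are fast. One example is sorteddict. A tree-based data structure may be another option.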
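A minimal sketch of the "keep it sorted" idea, using only the standard library's bisect module rather than sorteddict. The class name SortedKeyDict and its methods are hypothetical, deletion is not handled, and the code is meant to run on Python 3 (an assumption; the question targets 2.6).

```python
import bisect
from itertools import chain, islice

class SortedKeyDict(object):
    """Hypothetical helper: a dict plus a list of its keys kept in sorted order."""

    def __init__(self):
        self._data = {}
        self._keys = []  # invariant: always sorted

    def __setitem__(self, key, value):
        if key not in self._data:
            # Binary search plus a list insertion, paid once per new key.
            bisect.insort(self._keys, key)
        self._data[key] = value

    def __getitem__(self, key):
        return self._data[key]

    def select_items(self, n):
        """Return the first n values in key order without sorting again."""
        values = chain.from_iterable(self._data[k] for k in self._keys)
        return list(islice(values, n))

d = SortedKeyDict()
d[2] = [4, 5, 6]
d[1] = [1, 2, 3]
print(d.select_items(5))  # [1, 2, 3, 4, 5]
```

The trade-off is the one described above: each insertion of a new key costs more, while every selection skips the sort entirely.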

On to the benchmarks. The initial setup, taken from David Wolever's nice answer:

test_dict = dict((x, "a") for x in range(1000))
test_n = 300

Your version:

%timeit select_items(test_dict, test_n)
1000 loops, best of 3: 334 us per loop

This version:

%timeit select_items(test_dict, test_n)
10000 loops, best of 3: 49.1 us per loop
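One thing to keep in mind when reading this timing, offered as an observation rather than something the answer states: a generator expression evaluates its outermost iterable (here the sorted, sliced key list) immediately, but defers the dic[key] lookups until the generator is consumed. The sketch below makes that visible. LoggingDict is a hypothetical helper, and the code assumes Python 3.

```python
def select_items(dic, n):
    # sorted(...)[:n] runs right away; the dic[key] lookups are deferred.
    return (dic[key] for key in sorted(dic.keys())[:n])

lookups = []

class LoggingDict(dict):
    """Hypothetical dict subclass that records every key lookup."""
    def __getitem__(self, key):
        lookups.append(key)
        return dict.__getitem__(self, key)

test_dict = LoggingDict((x, "a") for x in range(1000))
gen = select_items(test_dict, 300)
print(len(lookups))   # 0: creating the generator performs no lookups
result = list(gen)
print(len(lookups))   # 300: the lookups happen only on consumption
```

So a timing of the bare call measures the sort and the slice, while the cost of producing the values is paid later by whoever iterates over the result.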

Answer 2: (score: 2)

from itertools import *
import heapq
islice(chain.from_iterable(imap(temp_dict.get, heapq.nsmallest(n, temp_dict))),n)
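For readers who want to run the expression above, here is a sketch that wraps it in a function. It assumes Python 3, where the built-in map is lazy and stands in for itertools.imap. The function name select_items_heapq and the sample dictionary are only illustrative.

```python
import heapq
from itertools import chain, islice

def select_items_heapq(temp_dict, n):
    """Return the first n values of temp_dict, taken in key order."""
    # At most n keys are needed, because the question states that every
    # key holds at least one value.
    smallest_keys = heapq.nsmallest(n, temp_dict)
    # Python 3 assumption: map() is lazy here, like itertools.imap in 2.x.
    values = chain.from_iterable(map(temp_dict.get, smallest_keys))
    return list(islice(values, n))

print(select_items_heapq({1: [1, 2, 3], 2: [4, 5, 6]}, 5))  # [1, 2, 3, 4, 5]
```

heapq.nsmallest avoids sorting the whole key set, which tends to pay off when n is small compared with the number of keys.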

Answer 3: (score: 1)

The answers given so far do not meet the user's specification.

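The data is a dictionary of sequences, and the desired result is a list of the first n elements of the dictionary's values, taken in key order.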

So if the data is:

{1: [1, 2, 3], 2: [4, 5, 6]}

then with n = 5 the result should be:

[1, 2, 3, 4, 5]

Given that, here is a script that compares the original function with a (slightly) optimized new version:

from timeit import timeit

def select_items_old(temp_dict, n):
  res = []
  sort_keys = sorted(temp_dict.keys())
  count = 0
  for key in sort_keys:
    for pair in temp_dict[key]:
      if count < n:
        res.append(pair)
        count += 1
      else:
        return res
  return res

def select_items_new(data, limit):
    count = 0
    result = []
    extend = result.extend
    for key in sorted(data.keys()):
        value = data[key]
        extend(value)
        count += len(value)
        if count >= limit:
            break
    return result[:limit]

# dict() with a generator also works on Python 2.6 (dict comprehensions need 2.7+)
data = dict((x, range(10)) for x in range(1000))

def compare(*args):
    number = 1000
    for func in args:
        name = func.__name__
        print ('test: %s(data, 12): %r' % (name, func(data, 12)))
        code = '%s(data, %d)' % (name, 300)
        duration = timeit(
            code, 'from __main__ import %s, data' % name, number=number)
        print ('time: %s: %.2f usec/pass\n' % (code, 1000000 * duration/number))

compare(select_items_old, select_items_new)

Output:

test: select_items_old(data, 12): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1]
time: select_items_old(data, 300): 163.81 usec/pass

test: select_items_new(data, 12): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1]
time: select_items_new(data, 300): 67.74 usec/pass