How to efficiently filter a dictionary with arbitrary-length tuples as keys?

Date: 2017-06-01 08:11:10

Tags: python performance dictionary filtering

TL;DR

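What is the most efficient way to implement a filter function for a dictionary whose keys have a variable dimension? The filter should take a tuple of the same dimension as the dictionary keys and return all keys in the dictionary that match it, i.e. the keys for which filter[i] is None or filter[i] == key[i] holds for every dimension i.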

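In my current project I need to process dictionaries that hold a lot of data. Their general structure is: keys are tuples of 2 to 4 integers, and values are integers. All keys within one dictionary have the same dimension. To illustrate, here are examples of the dictionaries I need to process: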

{(1, 2): 1, (1, 5): 2}
{(1, 5, 3): 2}
{(5, 2, 5, 2): 8}

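These dictionaries contain many entries; the largest has about 20,000. I frequently need to filter them, but usually I only look at certain indices of the key tuples. Ideally I want a function to which I can pass a filter tuple, and which then returns all keys that match it. If the filter tuple contains a None entry, that entry matches any value at that index of the dictionary's key tuples.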

An example of what the function should do for a dictionary with two-dimensional keys:

>>> d = {(1, 2): 1, (1, 5): 2, (2, 5): 1, (3, 9): 5}
>>> my_filter_fn(d, (1, None))
{(1, 2), (1, 5)}
>>> my_filter_fn(d, (None, 5))
{(1, 5), (2, 5)}
>>> my_filter_fn(d, (2, 4))
set()
>>> my_filter_fn(d, (None, None))
{(1, 2), (1, 5), (2, 5), (3, 9)}

Since my dictionaries have different tuple dimensions, I tried to solve this with a generator expression that takes the tuple dimension into account:

def my_filter_fn(entries: dict, match: tuple):
    return (x for x in entries.keys() if all(match[i] is None or match[i] == x[i]
                                             for i in range(len(match))))

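Unfortunately, this is quite slow compared with writing the condition out entirely by hand, e.g. (match[0] is None or match[0] == x[0]) and (match[1] is None or match[1] == x[1]); for 4 dimensions it is about 10 times slower. That is a problem for me, because I need to do this kind of filtering very often.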

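The following code demonstrates the performance problem. It is provided only to illustrate the problem and to make the test reproducible; you can skip it, the results follow below.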

import random
import timeit


def access_variable_length():
    for key in entry_keys:
        for k in (x for x in all_entries.keys() if all(key[i] is None or key[i] == x[i]
                                                       for i in range(len(key)))):
            pass


def access_static_length():
    for key in entry_keys:
        for k in (x for x in all_entries.keys() if
                  (key[0] is None or x[0] == key[0])
                  and (key[1] is None or x[1] == key[1])
                  and (key[2] is None or x[2] == key[2])
                  and (key[3] is None or x[3] == key[3])):
            pass


def get_rand_or_none(start, stop):
    number = random.randint(start-1, stop)
    if number == start-1:
        number = None
    return number


entry_keys = set()
for h in range(100):
    entry_keys.add((get_rand_or_none(1, 200), get_rand_or_none(1, 10), get_rand_or_none(1, 4), get_rand_or_none(1, 7)))
all_entries = dict()
for l in range(13000):
    all_entries[(random.randint(1, 200), random.randint(1, 10), random.randint(1, 4), random.randint(1, 7))] = 1

variable_time = timeit.timeit("access_variable_length()", "from __main__ import access_variable_length", number=10)
static_time = timeit.timeit("access_static_length()", "from __main__ import access_static_length", number=10)

print("variable length time: {}".format(variable_time))
print("static length time: {}".format(static_time))

Results:

variable length time: 9.625867042849316
static length time: 1.043319165662158

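I would like to avoid writing three separate functions, my_filter_fn2, my_filter_fn3 and my_filter_fn4, to cover every possible dictionary dimension and then filtering with a static dimension. I know that filtering with a variable dimension will always be slower than with a fixed one, but I hoped it would not be an order of magnitude slower. Since I am not a Python expert, I hope there is a clever way to restructure my variable-dimension generator expression so that it performs better.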

What is the most efficient way to filter a large dictionary in the way I described?

4 answers:

Answer 0 (score: 3)

Thanks for the chance to think about tuples in sets and dictionaries. It is a very useful and powerful corner of Python.

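Python is interpreted, so if you come from a compiled language, a good rule of thumb is to avoid complex nested iteration. If you find yourself writing complicated loops or comprehensions, it is always worth asking whether there is a better way.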

List subscripting (stuff[i]) and range(len(stuff)) are inefficient and verbose in Python, and rarely needed. Iterating directly is more efficient (and more natural):

for item in stuff:
    do_something(item)
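Applied to the filter in the question, this advice means pairing each filter slot with the corresponding key slot via zip instead of indexing with range(len(...)). A minimal sketch, assuming (as the question states) that every key has the same length as the filter; the name filter_keys is hypothetical, and it returns a set rather than a generator:

```python
def filter_keys(entries, match):
    # zip walks the filter tuple and the key tuple in lockstep,
    # so no index arithmetic or repeated len() calls are needed
    return {
        key for key in entries
        if all(m is None or m == v for m, v in zip(match, key))
    }


d = {(1, 2): 1, (1, 5): 2, (2, 5): 1, (3, 9): 5}
print(sorted(filter_keys(d, (None, 5))))  # [(1, 5), (2, 5)]
```

Note that zip stops at the shorter tuple, so a filter shorter than the keys would silently ignore the trailing positions.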

The following code is fast because it uses several of Python's strengths: comprehensions, dictionaries, sets and tuple unpacking.

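There is iteration, but it is simple and shallow. There is only one if statement in the whole filtering code, and it runs only 4 times per filter operation (once per position of a 4-element filter). This also helps performance, and makes the code easier to read.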

Explanation of the approach...

Each key in the original data:

{(1, 4, 5): 1}

is indexed by position and value:

{
    (0, 1): (1, 4, 5),
    (1, 4): (1, 4, 5),
    (2, 5): (1, 4, 5)
}

(Python numbers elements from zero.)

The indexes are collated into one large lookup dictionary of sets of tuples:

{
    (0, 1): {(1, 4, 5), (1, 6, 7), (1, 2), (1, 8), (1, 4, 2, 8), ...}
    (0, 2): {(2, 1), (2, 2), (2, 4, 1, 8), ...}
    (1, 4): {(1, 4, 5), (1, 4, 2, 8), (2, 4, 1, 8), ...}
    ...
}

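Once this lookup has been built (and it is built very efficiently), filtering is just set intersection and dictionary lookup, both of which are lightning fast. Filtering takes only microseconds, even for large dictionaries.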
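A compact, standard-library-only sketch of the same idea: build the (position, value) to set-of-keys lookup once, then answer each filter by intersecting the sets for its non-None slots. As a small variation that is not part of this answer's class, it intersects the smallest sets first so the working set shrinks as early as possible. The names build_lookup and query are hypothetical:

```python
from collections import defaultdict


def build_lookup(keys):
    # (position, value) -> set of keys that have `value` at `position`
    lookup = defaultdict(set)
    for key in keys:
        for pos, val in enumerate(key):
            lookup[(pos, val)].add(key)
    return lookup


def query(lookup, all_keys, pattern):
    # one candidate set per non-None slot; an unknown (pos, val) means no match
    candidates = [lookup.get((pos, val), set())
                  for pos, val in enumerate(pattern) if val is not None]
    if not candidates:
        return set(all_keys)          # an all-None filter matches everything
    candidates.sort(key=len)          # intersect the smallest sets first
    result = set(candidates[0])       # copy, so the lookup itself is not mutated
    for other in candidates[1:]:
        result &= other
        if not result:
            break
    return result


keys = [(1, 2), (1, 5), (2, 5), (3, 9)]
lookup = build_lookup(keys)
print(sorted(query(lookup, keys, (1, None))))  # [(1, 2), (1, 5)]
```

Like filtered() below, this sketch does not check key length, so with mixed-arity data a short filter also matches longer keys.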

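The approach handles data with tuples of arity 2, 3 or 4 (or any other), but arity_filtered() returns only keys with the same number of members as the filter tuple. So this class lets you either filter all the data together or handle the different tuple sizes separately, with little to choose between them in terms of performance.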
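If keys of different lengths share one dictionary, an alternative to post-filtering by length (as arity_filtered() does) is to keep a separate lookup per arity, so only keys of the filter's length are ever considered. A minimal sketch under that assumption; ArityIndex is a hypothetical name and not part of the answer's code:

```python
from collections import defaultdict


class ArityIndex:
    """Separate (position, value) lookups per key length."""

    def __init__(self, keys):
        self.by_arity = defaultdict(set)                     # length -> keys
        self.lookup = defaultdict(lambda: defaultdict(set))  # length -> (pos, val) -> keys
        for key in keys:
            n = len(key)
            self.by_arity[n].add(key)
            for pos, val in enumerate(key):
                self.lookup[n][(pos, val)].add(key)

    def filtered(self, pattern):
        n = len(pattern)
        result = set(self.by_arity.get(n, ()))  # start from same-length keys only
        index = self.lookup.get(n, {})
        for pos, val in enumerate(pattern):
            if val is not None:
                result &= index.get((pos, val), set())
        return result


idx = ArityIndex([(1, 2), (1, 5), (1, 5, 3), (5, 2, 5, 2)])
print(sorted(idx.filtered((1, None))))  # [(1, 2), (1, 5)]
```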

Timing results for a large random data set (11,500 tuples): 0.30 seconds to build the lookup and 0.007 seconds for 100 lookups.

from collections import defaultdict
import random
import timeit


class TupleFilter:
    def __init__(self, data):
        self.data = data
        self.lookup = self.build_lookup()

    def build_lookup(self):
        lookup = defaultdict(set)
        for data_item in self.data:
            for member_ref, data_key in tuple_index(data_item).items():
                lookup[member_ref].add(data_key)
        return lookup

    def filtered(self, tuple_filter):
        # initially unfiltered
        results = self.all_keys()
        # reduce filtered set
        for position, value in enumerate(tuple_filter):
            if value is not None:
                match_or_empty_set = self.lookup.get((position, value), set())
                results = results.intersection(match_or_empty_set)
        return results

    def arity_filtered(self, tuple_filter):
        tf_length = len(tuple_filter)
        return {match for match in self.filtered(tuple_filter) if tf_length == len(match)}

    def all_keys(self):
        return set(self.data.keys())


def tuple_index(item_key):
    member_refs = enumerate(item_key)
    return {(pos, val): item_key for pos, val in member_refs}


data = {
    (1, 2): 1,
    (1, 5): 2,
    (1, 5, 3): 2,
    (5, 2, 5, 2): 8
}

tests = {
     (1, 5): 2,
     (1, None, 3): 1,
     (1, None): 3,
     (None, 5): 2,
}

tf = TupleFilter(data)
for filter_tuple, expected_length in tests.items():
    result = tf.filtered(filter_tuple)
    print("Filter {0} finds {1}".format(filter_tuple, result))
    assert len(result) == expected_length
# same arity filtering: only keys with as many members as the filter
filter_tuple = (1, None)
print('Arity filtering: note two search results only: {0} => {1}'
      .format(filter_tuple, tf.arity_filtered(filter_tuple)))
# check unfiltered results return original data set
assert tf.filtered((None, None)) == tf.all_keys()


$ python filter.py
Filter (1, 5) finds {(1, 5), (1, 5, 3)}
Filter (1, None, 3) finds {(1, 5, 3)}
Filter (1, None) finds {(1, 2), (1, 5), (1, 5, 3)}
Filter (None, 5) finds {(1, 5), (1, 5, 3)}
Arity filtering: note two search results only: (1, None) => {(1, 2), (1, 5)}

Answer 1 (score: 2)

I made a few modifications:

  • You don't need the dict.keys method to iterate over the keys; iterating over the dict object itself gives us the keys,

  • Created separate modules, which makes the code easier to read and modify:

    • preparation.py with helpers for generating test data:

      import random
      
      left_ends = [200, 10, 4, 7]
      
      
      def generate_all_entries(count):
          return {tuple(random.randint(1, num)
                        for num in left_ends): 1
                  for _ in range(count)}
      
      
      def generate_entry_keys(count):
          return [tuple(get_rand_or_none(1, num)
                        for num in left_ends)
                  for _ in range(count)]
      
      
      def get_rand_or_none(start, stop):
          number = random.randint(start - 1, stop)
          if number == start - 1:
              number = None
          return number
      
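    • functions.py for the functions under test,
    • main.py for benchmarking.
  • Pass arguments to the functions instead of taking them from the global scope, so the static and variable length versions become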

    def access_static_length(all_entries, entry_keys):
        for key in entry_keys:
            for k in (x
                      for x in all_entries
                      if (key[0] is None or x[0] == key[0])
                      and (key[1] is None or x[1] == key[1])
                      and (key[2] is None or x[2] == key[2])
                      and (key[3] is None or x[3] == key[3])):
                pass
    
    
    def access_variable_length(all_entries, entry_keys):
        for key in entry_keys:
            for k in (x
                      for x in all_entries
                      if all(key[i] is None or key[i] == x[i]
                             for i in range(len(key)))):
                pass
    
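  • Use min on the results of timeit.repeat instead of timeit.timeit to get the most representative result (more in this answer),

  • Vary the number of entry_keys elements from 10 to 100 (inclusive) in steps of 10,

  • Vary the number of all_entries elements from 10000 to 15000 (inclusive) in steps of 500.

But back to the point.

Improvements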
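The min(timeit.repeat(...)) pattern from the list above also works with a callable instead of setup strings. A minimal sketch; the data size, the (None, 5) filter and the work function are hypothetical stand-ins, not the answer's benchmark:

```python
import random
import timeit

# hypothetical test data: up to 1,000 two-element keys
entries = {(random.randint(1, 200), random.randint(1, 10)): 1 for _ in range(1000)}
pattern = (None, 5)


def work():
    return [k for k in entries
            if all(m is None or m == v for m, v in zip(pattern, k))]


# repeat() returns one total per run; the minimum is the run least
# disturbed by other activity on the machine
timings = timeit.repeat(work, repeat=5, number=100)
best = min(timings)
print("best of 5: {:.4f} s per 100 calls".format(best))
```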

  1. We can improve filtering by skipping the index checks for None values in the key

    def access_variable_length_with_skipping_none(all_entries, entry_keys):
        for key in entry_keys:
            non_none_indexes = {i
                                for i, value in enumerate(key)
                                if value is not None}
            for k in (x
                      for x in all_entries.keys()
                      if all(key[i] == x[i]
                             for i in non_none_indexes)):
                pass
    
  2. The next suggestion is to use numpy

    import numpy as np
    
    
    def access_variable_length_numpy(all_entries, entry_keys):
        keys_array = np.array(list(all_entries))
        for entry_key in entry_keys:
            non_none_indexes = [i
                                for i, value in enumerate(entry_key)
                                if value is not None]
            non_none_values = [value
                               for i, value in enumerate(entry_key)
                               if value is not None]
            # a row matches only if ALL of its non-None positions match
            mask = (keys_array[:, non_none_indexes] == non_none_values).all(axis=1)
            indexes = np.flatnonzero(mask)
            for k in map(tuple, keys_array[indexes]):
                pass
    
  3. Benchmark

    Contents of main.py:

    import timeit
    from itertools import product
    
    number = 5
    repeat = 10
    for all_entries_count, entry_keys_count in product(range(10000, 15001, 500),
                                                       range(10, 101, 10)):
        print('all entries count: {}'.format(all_entries_count))
        print('entry keys count: {}'.format(entry_keys_count))
        preparation_part = ("from preparation import (generate_all_entries,\n"
                            "                         generate_entry_keys)\n"
                            "all_entries = generate_all_entries({all_entries_count})\n"
                            "entry_keys = generate_entry_keys({entry_keys_count})\n"
                            .format(all_entries_count=all_entries_count,
                                    entry_keys_count=entry_keys_count))
        static_time = min(timeit.repeat(
            "access_static_length(all_entries, entry_keys)",
            preparation_part + "from functions import access_static_length",
            repeat=repeat,
            number=number))
        variable_time = min(timeit.repeat(
            "access_variable_length(all_entries, entry_keys)",
            preparation_part + "from functions import access_variable_length",
            repeat=repeat,
            number=number))
        variable_time_with_skipping_none = min(timeit.repeat(
            "access_variable_length_with_skipping_none(all_entries, entry_keys)",
            preparation_part +
            "from functions import access_variable_length_with_skipping_none",
            repeat=repeat,
            number=number))
        variable_time_numpy = min(timeit.repeat(
            "access_variable_length_numpy(all_entries, entry_keys)",
            preparation_part +
            "from functions import access_variable_length_numpy",
            repeat=repeat,
            number=number))
    
        print("static length time: {}".format(static_time))
        print("variable length time: {}".format(variable_time))
        print("variable length time with skipping `None` keys: {}"
              .format(variable_time_with_skipping_none))
        print("variable length time with numpy: {}"
              .format(variable_time_numpy))
    

    Running it with Python 3.6.1 on my machine gives:

    all entries count: 10000
    entry keys count: 10
    static length time: 0.06314293399918824
    variable length time: 0.5234129569980723
    variable length time with skipping `None` keys: 0.2890012050011137
    variable length time with numpy: 0.22945181500108447
    all entries count: 10000
    entry keys count: 20
    static length time: 0.12795891799760284
    variable length time: 1.0610534609986644
    variable length time with skipping `None` keys: 0.5744297259989253
    variable length time with numpy: 0.5105678180007089
    all entries count: 10000
    entry keys count: 30
    static length time: 0.19210158399801003
    variable length time: 1.6491422000035527
    variable length time with skipping `None` keys: 0.8566724129996146
    variable length time with numpy: 0.7363859869983571
    all entries count: 10000
    entry keys count: 40
    static length time: 0.2561357790000329
    variable length time: 2.08878050599742
    variable length time with skipping `None` keys: 1.1256247100027394
    variable length time with numpy: 1.0066140279996034
    all entries count: 10000
    entry keys count: 50
    static length time: 0.32130833200062625
    variable length time: 2.6166040710013476
    variable length time with skipping `None` keys: 1.4147321179989376
    variable length time with numpy: 1.1700750320014777
    all entries count: 10000
    entry keys count: 60
    static length time: 0.38276188999952865
    variable length time: 3.153736616997776
    variable length time with skipping `None` keys: 1.7147898039984284
    variable length time with numpy: 1.4533947029995034
    all entries count: 10000
    entry keys count: 70
    ...
    all entries count: 15000
    entry keys count: 80
    static length time: 0.7141444490007416
    variable length time: 6.186657476999244
    variable length time with skipping `None` keys: 3.376506028998847
    variable length time with numpy: 3.1577993860009883
    all entries count: 15000
    entry keys count: 90
    static length time: 0.8115685330012639
    variable length time: 7.14327938399947
    variable length time with skipping `None` keys: 3.7462387939995097
    variable length time with numpy: 3.6140603050007485
    all entries count: 15000
    entry keys count: 100
    static length time: 0.8950150890013902
    variable length time: 7.829741768000531
    variable length time with skipping `None` keys: 4.1662235900003
    variable length time with numpy: 3.914334102999419
    

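    Summary

    We can see that the numpy version is not as good as expected, and this does not seem to be numpy's fault.

    If we remove the conversion of the filtered array records to tuple via map and leave just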

    for k in keys_array[indexes]:
        ...
    

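    then it is very fast (faster than the static-length version), so the problem lies in converting numpy.ndarray objects to tuple.

    Skipping the None entries of the input keys gives us roughly a 50% speed gain, so feel free to add it.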
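Building on the skip-None idea, the remaining per-key Python work can be reduced further by letting operator.itemgetter extract all constrained slots at once and comparing the results with a single ==. This is a different technique from those benchmarked above and its speed was not measured here; a sketch with the hypothetical name filter_keys_itemgetter:

```python
from operator import itemgetter


def filter_keys_itemgetter(entries, match):
    positions = [i for i, v in enumerate(match) if v is not None]
    if not positions:
        return set(entries)  # an all-None filter matches every key
    get = itemgetter(*positions)
    # itemgetter returns a bare value for one position and a tuple for
    # several; applying the same getter to `match` keeps both sides consistent
    wanted = get(match)
    return {key for key in entries if get(key) == wanted}


d = {(1, 2, 3, 4): 1, (1, 5, 3, 7): 2, (2, 5, 1, 4): 1}
print(sorted(filter_keys_itemgetter(d, (1, None, 3, None))))
# [(1, 2, 3, 4), (1, 5, 3, 7)]
```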

Answer 2 (score: 1)

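I don't have a pretty answer, and this kind of optimization tends to make code harder to read. But if you just need more speed, there are two things you can do.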

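First, we can move repeated calculations out of the loop. You said all entries in each dictionary have the same length, so you can compute it once instead of over and over inside the loop. This cut about 20% for me: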

def access_variable_length():
    try:
        # every key has the same length, so take it from the first one
        length = len(next(iter(entry_keys)))
    except StopIteration:  # entry_keys is empty
        return
    r = list(range(length))
    for key in entry_keys:
        for k in (x for x in all_entries.keys() if all(key[i] is None or key[i] == x[i]
                                                       for i in r)):
            pass

Not pretty, I agree. But we can make it even faster (and even uglier!) by using eval to build a fixed-length function, like this:

def access_variable_length_new():
    try:
        length = len(next(iter(entry_keys)))
    except StopIteration:  # entry_keys is empty
        return
    # build the hand-written condition as source text, one clause per position
    func_l = ["(key[{0}] is None or x[{0}] == key[{0}])".format(i) for i in range(length)]
    func_s = "lambda x,key: " + " and ".join(func_l)
    func = eval(func_s)
    for key in entry_keys:
        for k in (x for x in all_entries.keys() if func(x,key)):
            pass

For me this is almost as fast as the static version.
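Since the generated predicate depends only on the key length, it can be built once per dimension and reused across calls. A sketch using functools.lru_cache; the helper names predicate_for_length and filter_with_cached_predicate are hypothetical. eval is still involved, but only integers from range() are formatted into the source string:

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def predicate_for_length(length):
    # builds e.g. "lambda x, key: (key[0] is None or x[0] == key[0]) and ..."
    parts = ["(key[{0}] is None or x[{0}] == key[{0}])".format(i)
             for i in range(length)]
    # "True" keeps a zero-length filter syntactically valid
    return eval("lambda x, key: " + (" and ".join(parts) or "True"))


def filter_with_cached_predicate(entries, key):
    func = predicate_for_length(len(key))  # eval runs once per key length
    return [x for x in entries if func(x, key)]


d = {(1, 2): 1, (1, 5): 2, (2, 5): 1, (3, 9): 5}
print(filter_with_cached_predicate(d, (1, None)))  # [(1, 2), (1, 5)]
```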

Answer 3 (score: 0)

Suppose you have a dictionary d:

d = {(1,2):3,(1,4):5,(2,4):2,(1,3):4,(2,3):6,(5,1):5,(3,8):5,(3,6):9}

First you can get the dictionary keys:

keys = d.keys()
=>
dict_keys([(1, 2), (3, 8), (1, 3), (2, 3), (3, 6), (5, 1), (2, 4), (1, 4)])

Now let's define a function is_match that decides, according to your condition, whether two given tuples match:
is_match((1,7),(1,None)), is_match((1,5),(None,5)) and is_match((1,4),(1,4)) return True, while is_match((1,7),(1,8)) and is_match((4,7),(6,12)) return False.

def if_equal(a, b):
    # None on either side acts as a wildcard
    return a is None or b is None or a == b

def is_match(a, b):
    # every position must match
    return all(map(if_equal, a, b))

tup = (1, None)
matched_keys = [key for key in keys if is_match(key, tup)]
=>
[(1, 2), (1, 3), (1, 4)]
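To keep the values as well as the keys, the same matching rule can drive a dict comprehension. A short sketch that re-declares is_match in a compact all(...) form so the block runs on its own:

```python
def is_match(a, b):
    # None on either side acts as a wildcard, as in if_equal above
    return all(x is None or y is None or x == y for x, y in zip(a, b))


d = {(1, 2): 3, (1, 4): 5, (2, 4): 2, (1, 3): 4, (2, 3): 6, (5, 1): 5, (3, 8): 5, (3, 6): 9}
tup = (1, None)
filtered = {key: value for key, value in d.items() if is_match(key, tup)}
print(filtered)  # {(1, 2): 3, (1, 4): 5, (1, 3): 4}
```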