TL;DR
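What is the most efficient way to filter a dictionary whose keys are tuples of variable length? The filter should be a tuple of the same length as the dictionary keys, and the result should be every key in the dictionary that matches the filter, i.e. filter[i] is None or filter[i] == key[i] for every dimension i.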
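In my current project I need to process dictionaries that hold a lot of data. Their general structure is: tuples of 2 to 4 integers as keys and integers as values. All keys within one dictionary have the same length. To illustrate, here are examples of the dictionaries I need to process: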
{(1, 2): 1, (1, 5): 2}
{(1, 5, 3): 2}
{(5, 2, 5, 2): 8}
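These dictionaries contain many entries; the largest has about 20,000. I frequently need to filter them, but usually only on certain indices of the key tuples. Ideally I would have a function to which I pass a filter tuple and which returns all keys matching it. If the filter tuple contains a None entry, that entry matches any value at that index of the dictionary's key tuples.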
An example of what the function should do for a dictionary with two-dimensional keys:
>>> d = {(1, 2): 1, (1, 5): 2, (2, 5): 1, (3, 9): 5}
>>> my_filter_fn(d, (1, None))
{(1, 2), (1, 5)}
>>> my_filter_fn(d, (None, 5))
{(1, 5), (2, 5)}
>>> my_filter_fn(d, (2, 4))
set()
>>> my_filter_fn(d, (None, None))
{(1, 2), (1, 5), (2, 5), (3, 9)}
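For reference, a minimal sketch of the behavior specified above. It assumes the dictionary is passed in explicitly and that the result should be a set (as in the examples), and it assumes the filter has the same length as the keys, since zip stops at the shorter tuple:

```python
def my_filter_fn(entries, match):
    """Return the set of keys in `entries` that match the filter tuple `match`.

    A None in `match` acts as a wildcard for that position.
    """
    return {
        key
        for key in entries
        if all(m is None or m == k for m, k in zip(match, key))
    }


d = {(1, 2): 1, (1, 5): 2, (2, 5): 1, (3, 9): 5}
print(my_filter_fn(d, (1, None)))  # the keys (1, 2) and (1, 5)
print(my_filter_fn(d, (2, 4)))     # set()
```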
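Since my dictionaries have keys of different lengths, I tried to solve this with a generator expression that takes the length of the tuple into account: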
def my_filter_fn(entries: dict, match: tuple):
    return (x for x in entries.keys() if all(match[i] is None or match[i] == x[i]
                                             for i in range(len(match))))
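Unfortunately, this is quite slow compared with writing out the condition entirely by hand ((match[0] is None or match[0] == x[0]) and (match[1] is None or match[1] == x[1])); for 4 dimensions it is about 10 times slower. This is a problem for me because I need to do this kind of filtering very often.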
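The following code demonstrates the performance problem. It is provided only to illustrate the problem and to make the test reproducible; you can skip the code, the results follow below.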
import random
import timeit
def access_variable_length():
    for key in entry_keys:
        for k in (x for x in all_entries.keys() if all(key[i] is None or key[i] == x[i]
                                                       for i in range(len(key)))):
            pass
def access_static_length():
    for key in entry_keys:
        for k in (x for x in all_entries.keys() if
                  (key[0] is None or x[0] == key[0])
                  and (key[1] is None or x[1] == key[1])
                  and (key[2] is None or x[2] == key[2])
                  and (key[3] is None or x[3] == key[3])):
            pass
def get_rand_or_none(start, stop):
    number = random.randint(start-1, stop)
    if number == start-1:
        number = None
    return number
entry_keys = set()
for h in range(100):
    entry_keys.add((get_rand_or_none(1, 200), get_rand_or_none(1, 10), get_rand_or_none(1, 4), get_rand_or_none(1, 7)))
all_entries = dict()
for l in range(13000):
    all_entries[(random.randint(1, 200), random.randint(1, 10), random.randint(1, 4), random.randint(1, 7))] = 1
variable_time = timeit.timeit("access_variable_length()", "from __main__ import access_variable_length", number=10)
static_time = timeit.timeit("access_static_length()", "from __main__ import access_static_length", number=10)
print("variable length time: {}".format(variable_time))
print("static length time: {}".format(static_time))
Results:
variable length time: 9.625867042849316
static length time: 1.043319165662158
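I would like to avoid writing three separate functions, my_filter_fn2, my_filter_fn3 and my_filter_fn4, one for each possible key length, each using static-length filtering. I know variable-length filtering will always be slower than fixed-length filtering, but I hoped it would not be nearly 10 times slower. Since I am not a Python expert, I hope there is a clever way to restructure my variable-length generator expression to get better performance.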
What is the most efficient way to filter large dictionaries in the way I described?
Answer 0 (score: 3)
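Thanks for the chance to think about tuples in sets and dicts. It is a very useful and powerful corner of Python.
Python is interpreted, so if you come from a compiled language, a good rule of thumb is to avoid complex nested iteration. If you find yourself writing complicated loops or comprehensions, it is always worth asking whether there is a better way.
List subscripts (stuff[i]) and range(len(stuff)) are inefficient and verbose in Python, and are rarely needed. Iterating directly is more efficient (and more natural):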
for item in stuff:
    do_something(item)
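Applied to the filter condition from the question, iterating with zip instead of indexing might look like this. It is a sketch: `matches` is a hypothetical helper name, and it assumes the filter and the key have the same length, because zip stops at the shorter tuple:

```python
def matches(match, key):
    # Walk the filter and the key in lockstep; no index arithmetic needed.
    return all(m is None or m == k for m, k in zip(match, key))


keys = [(1, 2), (1, 5), (2, 5), (3, 9)]
print([k for k in keys if matches((None, 5), k)])  # [(1, 5), (2, 5)]
```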
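The code below is fast because it uses some of Python's strengths: comprehensions, dictionaries, sets and tuple unpacking.
There is iteration, but it is simple and shallow. There is only one if statement in the whole code, executed only 4 times per filter operation. That also helps performance, and makes the code easier to read.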
An explanation of the approach...
Each key in the original data:
{(1, 4, 5): 1}
is indexed by position and value:
{
(0, 1): (1, 4, 5),
(1, 4): (1, 4, 5),
(2, 5): (1, 4, 5)
}
(Python numbers elements from zero.)
The indexes are collated into one big lookup dictionary of sets of tuples:
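The per-key index shown above is just enumerate folded into a dict comprehension, which is what the tuple_index helper in the code below does. For the example key:

```python
key = (1, 4, 5)
# Map each (position, value) pair to the key it came from.
index = {(pos, val): key for pos, val in enumerate(key)}
print(index)  # {(0, 1): (1, 4, 5), (1, 4): (1, 4, 5), (2, 5): (1, 4, 5)}
```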
{
    (0, 1): {(1, 4, 5), (1, 6, 7), (1, 2), (1, 8), (1, 4, 2, 8), ...},
    (0, 2): {(2, 1), (2, 2), (2, 4, 1, 8), ...},
    (1, 4): {(1, 4, 5), (1, 4, 2, 8), (2, 4, 1, 8), ...},
    ...
}
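Once this lookup is built (and building it is efficient), filtering is just set intersection and dictionary lookup, both of which are very fast. Even on large dictionaries, a single filter takes well under a millisecond.
The approach handles data with tuples of arity 2, 3 or 4 (or any other), but arity_filtered() returns only keys with the same number of members as the filter tuple. So this class gives you the option of filtering all the data together or handling the different tuple sizes separately, with little to choose between them in terms of performance.
Timing on a large random dataset (11,500 tuples): 0.30 s to build the lookup and 0.007 s for 100 lookups.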
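A condensed sketch of that idea outside the class, assuming every key is a tuple of hashable values: build the lookup once, then intersect the sets for the non-None filter positions. Sorting the sets so the smallest comes first is an optional refinement of this sketch, not something the class below does; `build_lookup` and `lookup_filter` are hypothetical names:

```python
from collections import defaultdict


def build_lookup(keys):
    # (position, value) -> set of keys holding `value` at `position`
    lookup = defaultdict(set)
    for key in keys:
        for pos, val in enumerate(key):
            lookup[(pos, val)].add(key)
    return lookup


def lookup_filter(lookup, all_keys, match):
    # One candidate set per constrained position; .get avoids inserting
    # empty sets into the defaultdict for values that never occur.
    sets = [lookup.get((pos, val), set())
            for pos, val in enumerate(match) if val is not None]
    if not sets:
        return set(all_keys)  # all wildcards: everything matches
    sets.sort(key=len)  # start intersecting from the smallest set
    # intersection() returns a new set, so the lookup is never mutated
    return sets[0].intersection(*sets[1:])


data = {(1, 2): 1, (1, 5): 2, (2, 5): 1, (3, 9): 5}
lookup = build_lookup(data)
print(lookup_filter(lookup, data, (None, 5)))  # the keys (1, 5) and (2, 5)
```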
from collections import defaultdict
import random
import timeit
class TupleFilter:
    def __init__(self, data):
        self.data = data
        self.lookup = self.build_lookup()
    def build_lookup(self):
        lookup = defaultdict(set)
        for data_item in self.data:
            for member_ref, data_key in tuple_index(data_item).items():
                lookup[member_ref].add(data_key)
        return lookup
    def filtered(self, tuple_filter):
        # initially unfiltered
        results = self.all_keys()
        # reduce filtered set
        for position, value in enumerate(tuple_filter):
            if value is not None:
                match_or_empty_set = self.lookup.get((position, value), set())
                results = results.intersection(match_or_empty_set)
        return results
    def arity_filtered(self, tuple_filter):
        tf_length = len(tuple_filter)
        return {match for match in self.filtered(tuple_filter) if tf_length == len(match)}
    def all_keys(self):
        return set(self.data.keys())
def tuple_index(item_key):
    member_refs = enumerate(item_key)
    return {(pos, val): item_key for pos, val in member_refs}
data = {
    (1, 2): 1,
    (1, 5): 2,
    (1, 5, 3): 2,
    (5, 2, 5, 2): 8
}
tests = {
    (1, 5): 2,
    (1, None, 3): 1,
    (1, None): 3,
    (None, 5): 2,
}
tf = TupleFilter(data)
for filter_tuple, expected_length in tests.items():
    result = tf.filtered(filter_tuple)
    print("Filter {0} => {1}".format(filter_tuple, result))
    assert len(result) == expected_length
# same arity filtering
filter_tuple = (1, None)
print('Not arity matched: {0} => {1}'
      .format(filter_tuple, tf.filtered(filter_tuple)))
print('Arity matched: {0} => {1}'
      .format(filter_tuple, tf.arity_filtered(filter_tuple)))
# check unfiltered results return original data set
assert tf.filtered((None, None)) == tf.all_keys()
$ python filter.py
Filter (1, 5) => {(1, 5), (1, 5, 3)}
Filter (1, None, 3) => {(1, 5, 3)}
Filter (1, None) => {(1, 2), (1, 5), (1, 5, 3)}
Filter (None, 5) => {(1, 5), (1, 5, 3)}
Not arity matched: (1, None) => {(1, 2), (1, 5), (1, 5, 3)}
Arity matched: (1, None) => {(1, 2), (1, 5)}
Answer 1 (score: 2)
I made a few changes:
- you don't need the dict.keys method to iterate over the keys; iterating over the dict object itself gives us its keys,
- created separate modules, which makes the code easier to read and modify:
preparation.py
with helpers for generating test data:
import random
left_ends = [200, 10, 4, 7]
def generate_all_entries(count):
    return {tuple(random.randint(1, num)
                  for num in left_ends): 1
            for _ in range(count)}
def generate_entry_keys(count):
    return [tuple(get_rand_or_none(1, num)
                  for num in left_ends)
            for _ in range(count)]
def get_rand_or_none(start, stop):
    number = random.randint(start - 1, stop)
    if number == start - 1:
        number = None
    return number
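functions.py for the functions under test, and main.py for benchmarking,
- passed arguments to the functions instead of taking them from the global scope, so the static and variable-length versions become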
def access_static_length(all_entries, entry_keys):
    for key in entry_keys:
        for k in (x
                  for x in all_entries
                  if (key[0] is None or x[0] == key[0])
                  and (key[1] is None or x[1] == key[1])
                  and (key[2] is None or x[2] == key[2])
                  and (key[3] is None or x[3] == key[3])):
            pass
def access_variable_length(all_entries, entry_keys):
    for key in entry_keys:
        for k in (x
                  for x in all_entries
                  if all(key[i] is None or key[i] == x[i]
                         for i in range(len(key)))):
            pass
- used min on the results of timeit.repeat instead of timeit.timeit to get the most representative result (more in this answer),
- varied the number of entry_keys elements from 10 to 100 (inclusive) in steps of 10,
- varied the number of all_entries elements from 10000 to 15000 (inclusive) in steps of 500.
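A minimal illustration of that measurement pattern, using a throwaway statement rather than the filter functions: timeit.repeat returns one total per repetition, and the minimum is the run least disturbed by other activity on the machine:

```python
import timeit

# Five repetitions, each timing 1000 executions of the statement.
times = timeit.repeat("sum(range(100))", repeat=5, number=1000)
best = min(times)  # seconds for 1000 executions in the least-disturbed run
print("best of {}: {:.6f} s".format(len(times), best))
```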
But back to the point.
We can improve the filtering by skipping the index checks for positions whose value in the filter key is None:
def access_variable_length_with_skipping_none(all_entries, entry_keys):
    for key in entry_keys:
        non_none_indexes = {i
                            for i, value in enumerate(key)
                            if value is not None}
        for k in (x
                  for x in all_entries.keys()
                  if all(key[i] == x[i]
                         for i in non_none_indexes)):
            pass
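Going one step further than this answer (a different technique, not benchmarked here, so time it against the versions above before relying on it): operator.itemgetter can pull out all non-None positions in a single call, and applying the same getter to the filter tuple produces the value to compare against. `filter_keys_itemgetter` is a hypothetical name, and the sketch assumes all keys have the same length as the filter:

```python
from operator import itemgetter


def filter_keys_itemgetter(all_entries, match):
    """Return keys of `all_entries` matching `match` (None = wildcard)."""
    positions = [i for i, value in enumerate(match) if value is not None]
    if not positions:
        return set(all_entries)  # itemgetter() needs at least one index
    getter = itemgetter(*positions)
    # itemgetter returns a bare value for one index and a tuple for several;
    # applying the same getter to `match` yields a target of the same shape.
    target = getter(match)
    return {key for key in all_entries if getter(key) == target}


entries = {(1, 2, 3): 1, (1, 5, 3): 2, (2, 5, 3): 1}
print(filter_keys_itemgetter(entries, (None, 5, 3)))  # (1, 5, 3) and (2, 5, 3)
```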
The next suggestion is to use numpy:
import numpy as np
def access_variable_length_numpy(all_entries, entry_keys):
    keys_array = np.array(list(all_entries))
    for entry_key in entry_keys:
        non_none_indexes = [i
                            for i, value in enumerate(entry_key)
                            if value is not None]
        non_none_values = [value
                           for i, value in enumerate(entry_key)
                           if value is not None]
        # a row matches only if it agrees on every non-None position
        mask = (keys_array[:, non_none_indexes] == non_none_values).all(axis=1)
        indexes = np.flatnonzero(mask)
        for k in map(tuple, keys_array[indexes]):
            pass
Contents of main.py:
import timeit
from itertools import product
number = 5
repeat = 10
for all_entries_count, entry_keys_count in product(range(10000, 15001, 500),
                                                   range(10, 101, 10)):
    print('all entries count: {}'.format(all_entries_count))
    print('entry keys count: {}'.format(entry_keys_count))
    preparation_part = ("from preparation import (generate_all_entries,\n"
                        "                         generate_entry_keys)\n"
                        "all_entries = generate_all_entries({all_entries_count})\n"
                        "entry_keys = generate_entry_keys({entry_keys_count})\n"
                        .format(all_entries_count=all_entries_count,
                                entry_keys_count=entry_keys_count))
    static_time = min(timeit.repeat(
        "access_static_length(all_entries, entry_keys)",
        preparation_part + "from functions import access_static_length",
        repeat=repeat,
        number=number))
    variable_time = min(timeit.repeat(
        "access_variable_length(all_entries, entry_keys)",
        preparation_part + "from functions import access_variable_length",
        repeat=repeat,
        number=number))
    variable_time_with_skipping_none = min(timeit.repeat(
        "access_variable_length_with_skipping_none(all_entries, entry_keys)",
        preparation_part +
        "from functions import access_variable_length_with_skipping_none",
        repeat=repeat,
        number=number))
    variable_time_numpy = min(timeit.repeat(
        "access_variable_length_numpy(all_entries, entry_keys)",
        preparation_part +
        "from functions import access_variable_length_numpy",
        repeat=repeat,
        number=number))
    print("static length time: {}".format(static_time))
    print("variable length time: {}".format(variable_time))
    print("variable length time with skipping `None` keys: {}"
          .format(variable_time_with_skipping_none))
    print("variable length time with numpy: {}"
          .format(variable_time_numpy))
On my machine, Python 3.6.1 gives:
all entries count: 10000
entry keys count: 10
static length time: 0.06314293399918824
variable length time: 0.5234129569980723
variable length time with skipping `None` keys: 0.2890012050011137
variable length time with numpy: 0.22945181500108447
all entries count: 10000
entry keys count: 20
static length time: 0.12795891799760284
variable length time: 1.0610534609986644
variable length time with skipping `None` keys: 0.5744297259989253
variable length time with numpy: 0.5105678180007089
all entries count: 10000
entry keys count: 30
static length time: 0.19210158399801003
variable length time: 1.6491422000035527
variable length time with skipping `None` keys: 0.8566724129996146
variable length time with numpy: 0.7363859869983571
all entries count: 10000
entry keys count: 40
static length time: 0.2561357790000329
variable length time: 2.08878050599742
variable length time with skipping `None` keys: 1.1256247100027394
variable length time with numpy: 1.0066140279996034
all entries count: 10000
entry keys count: 50
static length time: 0.32130833200062625
variable length time: 2.6166040710013476
variable length time with skipping `None` keys: 1.4147321179989376
variable length time with numpy: 1.1700750320014777
all entries count: 10000
entry keys count: 60
static length time: 0.38276188999952865
variable length time: 3.153736616997776
variable length time with skipping `None` keys: 1.7147898039984284
variable length time with numpy: 1.4533947029995034
all entries count: 10000
entry keys count: 70
...
all entries count: 15000
entry keys count: 80
static length time: 0.7141444490007416
variable length time: 6.186657476999244
variable length time with skipping `None` keys: 3.376506028998847
variable length time with numpy: 3.1577993860009883
all entries count: 15000
entry keys count: 90
static length time: 0.8115685330012639
variable length time: 7.14327938399947
variable length time with skipping `None` keys: 3.7462387939995097
variable length time with numpy: 3.6140603050007485
all entries count: 15000
entry keys count: 100
static length time: 0.8950150890013902
variable length time: 7.829741768000531
variable length time with skipping `None` keys: 4.1662235900003
variable length time with numpy: 3.914334102999419
We can see that the numpy version is not as good as expected, and it doesn't seem to be numpy's fault.
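If we remove the conversion of the filtered array rows to tuples via map and just leave
for k in keys_array[indexes]:
    ...
then it runs very fast (faster than the static-length version), so the problem lies in the conversion from numpy.ndarray objects to tuple.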
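If the conversion is the bottleneck, one common mitigation (an assumption to verify with your own timings, not something measured in this answer) is to convert the selected rows in bulk with ndarray.tolist(), which yields plain Python ints, and only then build tuples. The sketch also combines the per-column comparisons with .all(axis=1), so only rows matching every non-None position are kept; `filter_keys_numpy` is a hypothetical name:

```python
import numpy as np


def filter_keys_numpy(keys_array, match):
    """Return matching rows of a 2-D key array as a list of tuples."""
    cols = [i for i, v in enumerate(match) if v is not None]
    vals = [v for v in match if v is not None]
    # Row i is kept only if it equals `match` at every constrained column;
    # with no constrained columns, .all(axis=1) is True for every row.
    mask = (keys_array[:, cols] == vals).all(axis=1)
    # tolist() converts the whole selection to nested lists of Python ints.
    return [tuple(row) for row in keys_array[mask].tolist()]


keys_array = np.array([(1, 2), (1, 5), (2, 5), (3, 9)])
print(filter_keys_numpy(keys_array, (None, 5)))  # [(1, 5), (2, 5)]
```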
Skipping the None positions of the filter keys gives us roughly a 50% speed gain, so feel free to add it.
Answer 2 (score: 1)
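I don't have a pretty answer, since this kind of optimization tends to make code harder to read. But if you just need more speed, there are two things you can do.
First, we can move a repeated computation out of the loop. You say all entries in each dictionary have the same length, so you can compute the length once instead of recomputing it inside the loop. This cut about 20% for me: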
def access_variable_length():
    try:
        # Python 3: next(iter(...)); an empty set raises StopIteration
        length = len(next(iter(entry_keys)))
    except StopIteration:
        return
    r = list(range(length))
    for key in entry_keys:
        for k in (x for x in all_entries.keys() if all(key[i] is None or key[i] == x[i]
                                                       for i in r)):
            pass
Not pretty, I agree. But we can make it even faster (and even uglier!) by using eval to build a fixed-length function. Like this:
def access_variable_length_new():
    try:
        length = len(next(iter(entry_keys)))
    except StopIteration:
        return
    # Generate the hand-written condition as source text and compile it once
    func_l = ["(key[{0}] is None or x[{0}] == key[{0}])".format(i) for i in range(length)]
    func_s = "lambda x, key: " + " and ".join(func_l)
    func = eval(func_s)
    for key in entry_keys:
        for k in (x for x in all_entries.keys() if func(x, key)):
            pass
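Since the generated predicate depends only on the tuple length, it could be built once per length and cached rather than rebuilt on every call. A sketch using functools.lru_cache; `make_matcher` is a hypothetical helper, the "True" fallback for length 0 is an addition of this sketch, and the usual eval caveat applies (only feed it generated source, never user input):

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def make_matcher(length):
    """Build (once per length) a predicate equivalent to the hand-written check."""
    terms = ["(key[{0}] is None or x[{0}] == key[{0}])".format(i)
             for i in range(length)]
    # An empty join would be a syntax error, so length 0 matches everything.
    source = "lambda x, key: " + (" and ".join(terms) or "True")
    return eval(source)  # source is generated above, never user input


match = make_matcher(2)
print(match((1, 5), (None, 5)))  # True
print(make_matcher(2) is match)  # True: cached, not rebuilt
```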
For me this is almost as fast as the static version.
Answer 3 (score: 0)
Suppose you have a dictionary d:
d = {(1,2):3,(1,4):5,(2,4):2,(1,3):4,(2,3):6,(5,1):5,(3,8):5,(3,6):9}
First, you can get the dictionary's keys:
keys = d.keys()
=>
dict_keys([(1, 2), (3, 8), (1, 3), (2, 3), (3, 6), (5, 1), (2, 4), (1, 4)])
Now let's define a function is_match that decides, according to your condition, whether two given tuples match: is_match((1,7),(1,None)), is_match((1,5),(None,5)) and is_match((1,4),(1,4)) return True, while is_match((1,7),(1,8)) and is_match((4,7),(6,12)) return False.
def if_equal(a, b):
    # None on either side acts as a wildcard
    return a is None or b is None or a == b
is_match = lambda a, b: all(map(if_equal, a, b))
tup = (1, None)
matched_keys = [key for key in keys if is_match(key, tup)]
=>
[(1, 2), (1, 3), (1, 4)]