是否有一种聪明的方法将密钥传递给defaultdict的default_factory?

时间:2010-05-26 10:55:15

标签: python

一个类有一个带有一个参数的构造函数:

class C(object):
    def __init__(self, v):
        self.v = v
        ...

在代码的某处,dict中的值可以知道它们的键 我想使用defaultdict,并将密钥传递给新生默认值:

d = defaultdict(lambda : C(here_i_wish_the_key_to_be))

有什么建议吗?

4 个答案:

答案 0 :(得分:102)

它很难成为聪明的 - 但是子类化是你的朋友:

class keydefaultdict(defaultdict):
    def __missing__(self, key):
        if self.default_factory is None:
            raise KeyError( key )
        else:
            ret = self[key] = self.default_factory(key)
            return ret

d = keydefaultdict(C)
d[x] # returns C(x)

答案 1 :(得分:19)

不,没有。

无法将defaultdict实施配置为将缺失的key传递给开箱即用的default_factory。您唯一的选择是按照@JochenRitzel的建议实现您自己的defaultdict子类。

但那并不聪明"或者几乎和标准库解决方案一样干净(如果它存在的话)。 因此,您的简洁,是/否问题的答案显然是"否"。

标准库缺少这种经常需要的工具太糟糕了。

答案 2 :(得分:6)

我认为你根本不需要defaultdict。为什么不使用dict.setdefault方法?

>>> d = {}
>>> d.setdefault('p', C('p')).v
'p'

这当然会创建C的许多实例。如果这是一个问题,我认为更简单的方法将会:

>>> d = {}
>>> if 'e' not in d: d['e'] = C('e')

就我所见,它会比defaultdict或任何其他替代方案更快。

ETA 关于in测试与使用try-except子句的速度:

>>> def g():
    d = {}
    if 'a' in d:
        return d['a']


>>> timeit.timeit(g)
0.19638929363557622
>>> def f():
    d = {}
    try:
        return d['a']
    except KeyError:
        return


>>> timeit.timeit(f)
0.6167065411074759
>>> def k():
    d = {'a': 2}
    if 'a' in d:
        return d['a']


>>> timeit.timeit(k)
0.30074866358404506
>>> def p():
    d = {'a': 2}
    try:
        return d['a']
    except KeyError:
        return


>>> timeit.timeit(p)
0.28588609450770264

答案 3 :(得分:0)

这是一个自动添加值的字典的工作示例。在/ usr / include中查找重复文件的演示任务。请注意,自定义词典 PathDict 仅需要四行:

class FullPaths:

    def __init__(self,filename):
        self.filename = filename
        self.paths = set()

    def record_path(self,path):
        self.paths.add(path)

class PathDict(dict):

    def __missing__(self, key):
        ret = self[key] = FullPaths(key)
        return ret

if __name__ == "__main__":
    pathdict = PathDict()
    for root, _, files in os.walk('/usr/include'):
        for f in files:
            path = os.path.join(root,f)
            pathdict[f].record_path(path)
    for fullpath in pathdict.values():
        if len(fullpath.paths) > 1:
            print("{} located in {}".format(fullpath.filename,','.join(fullpath.paths)))