使用ruamel.yaml安全地转储和加载defaultdict

时间:2018-10-23 15:20:47

标签: python yaml defaultdict ruamel.yaml

我正在尝试在Python(我的情况是3.6+)中使用ruamel.yaml,对带有collections.defaultdict属性的类进行序列化和反序列化。

这是我要开始工作的一个最小示例:

from collections import defaultdict
import ruamel.yaml
from pathlib import Path

# Minimal reproduction from the question: a registered class whose
# attribute is a defaultdict.
class Foo:
    def __init__(self):
        self.x = defaultdict()


YAML = ruamel.yaml.YAML(typ="safe")
YAML.register_class(Foo)
# Registering defaultdict makes it go through represent_yaml_object, which
# reads data.__dict__; defaultdict instances have no __dict__, so the dump
# below raises AttributeError (see the traceback).
YAML.register_class(defaultdict)

fp =  Path("./test.yaml")
YAML.dump(Foo(), fp)
YAML.load(fp)

但这失败了:

AttributeError: 'collections.defaultdict' object has no attribute '__dict__'

有没有办法不必为每个类似Foo的类都编写自定义代码?我希望能为defaultdict对象添加一个不同的表示器(representer),但到目前为止,我的尝试都没有成功。

完整追溯:

Traceback (most recent call last):
File "./tests/test_yaml.py", line 18, in <module>
    YAML.dump(Foo(), fp)
File "C:\miniconda-windows\envs\ratio\lib\site-packages\ruamel\yaml\main.py", line 439, in dump
    return self.dump_all([data], stream, _kw, transform=transform)
File "C:\miniconda-windows\envs\ratio\lib\site-packages\ruamel\yaml\main.py", line 453, in dump_all
    self._context_manager.dump(data)
File "C:\miniconda-windows\envs\ratio\lib\site-packages\ruamel\yaml\main.py", line 801, in dump
    self._yaml.representer.represent(data)
File "C:\miniconda-windows\envs\ratio\lib\site-packages\ruamel\yaml\representer.py", line 81, in represent
    node = self.represent_data(data)
File "C:\miniconda-windows\envs\ratio\lib\site-packages\ruamel\yaml\representer.py", line 108, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
File "C:\miniconda-windows\envs\ratio\lib\site-packages\ruamel\yaml\main.py", line 638, in t_y
    tag, data, cls, flow_style=representer.default_flow_style
File "C:\miniconda-windows\envs\ratio\lib\site-packages\ruamel\yaml\representer.py", line 384, in represent_yaml_object
    return self.represent_mapping(tag, state, flow_style=flow_style)
File "C:\miniconda-windows\envs\ratio\lib\site-packages\ruamel\yaml\representer.py", line 218, in represent_mapping
    node_value = self.represent_data(item_value)
File "C:\miniconda-windows\envs\ratio\lib\site-packages\ruamel\yaml\representer.py", line 108, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
File "C:\miniconda-windows\envs\ratio\lib\site-packages\ruamel\yaml\main.py", line 638, in t_y
    tag, data, cls, flow_style=representer.default_flow_style
File "C:\miniconda-windows\envs\ratio\lib\site-packages\ruamel\yaml\representer.py", line 383, in represent_yaml_object
    state = data.__dict__.copy()
AttributeError: 'collections.defaultdict' object has no attribute '__dict__'

2 个答案:

答案 0 :(得分:2)

defaultdict无法被表示的问题可以相对容易地解决:扩展represent_yaml_object方法,使其在发生AttributeError时退而直接复制数据本身。

具体做法是:子类化typ='safe'时所使用的SafeRepresenter,并把该子类赋给YAML实例的Representer属性来完成注册。由于该方法很短,不妨替换整个方法,只修改其中获取state的那部分。

这样做之后,您会发现结果还无法加载,因此需要对SafeConstructor做同样的处理:

from collections import defaultdict
import ruamel.yaml
from pathlib import Path

class MyRepresenter(ruamel.yaml.SafeRepresenter):
    """SafeRepresenter whose represent_yaml_object also handles dict subclasses.

    Registered classes are normally dumped from their instance ``__dict__``;
    objects that have none (e.g. ``defaultdict``) are dumped from a copy of
    their own mapping contents instead.
    """

    def represent_yaml_object(self, tag, data, cls, flow_style=None):
        """Represent *data* as a mapping node tagged *tag*.

        The state comes from ``data.__getstate__()`` when that exists and
        returns something other than None; otherwise from a copy of
        ``data.__dict__``; and, for objects without ``__dict__``, from
        ``data.copy()``.
        """
        state = data.__getstate__() if hasattr(data, '__getstate__') else None
        if state is None:
            # Since Python 3.11 every object has a default __getstate__, which
            # returns None e.g. for objects without an instance __dict__ such
            # as a defaultdict, so hasattr() alone no longer selects the
            # fallback; build the state ourselves in that case.
            try:
                state = data.__dict__.copy()
            except AttributeError:
                # dict subclasses like defaultdict: dump their items directly.
                state = data.copy()
        return self.represent_mapping(tag, state, flow_style=flow_style)


class MyConstructor(ruamel.yaml.SafeConstructor):
    """SafeConstructor that can also rebuild registered dict subclasses.

    ``construct_yaml_object`` normally restores state into the instance
    ``__dict__``; objects without one (e.g. ``defaultdict``) get the
    constructed mapping merged in through their own ``update`` method.
    """

    def construct_yaml_object(self, node, cls):
        """Build an instance of *cls* from *node*.

        The bare instance (created with ``cls.__new__``) is yielded before its
        state is filled in.  Classes defining ``__setstate__`` receive a deeply
        constructed mapping; all others are populated via ``__dict__`` or,
        lacking that, via ``update``.
        """
        instance = cls.__new__(cls)
        yield instance
        if hasattr(instance, '__setstate__'):
            state = self.construct_mapping(node, deep=True)
            instance.__setstate__(state)
            return
        mapping = self.construct_mapping(node)
        try:
            instance.__dict__.update(mapping)
        except AttributeError:
            # No instance __dict__ (dict subclasses such as defaultdict).
            instance.update(mapping)

yaml = ruamel.yaml.YAML(typ="safe")
# Plug in the subclasses defined above by assigning them to the YAML
# instance's Representer / Constructor attributes.
yaml.Representer = MyRepresenter
yaml.Constructor = MyConstructor

# register_class can also be used as a class decorator.
@yaml.register_class
class Foo:
    def __init__(self):
        self.x = defaultdict()

# defaultdict is now dumped/loaded under the !defaultdict tag through the
# represent_yaml_object / construct_yaml_object overrides.
yaml.register_class(defaultdict)

fp =  Path("./test.yaml")
yaml.dump(Foo(), fp)
d = yaml.load(fp)

print(fp.read_text(), end='')

给出:

!Foo
x: !defaultdict {}

我认为把变量名全部写成大写不是个好主意,尤其是当它与类同名时。我还顺便展示了如何使用装饰器来注册类(以防您不知道可以这样做)。


上述做法有一个缺陷:如果您的defaultdict设置了default_factory参数(其默认值为None),该参数会在转储和加载的过程中丢失。

要支持这样的defaultdict,需要将其转储为带标记的序列:

!defaultdict
- !list
- a: [1]
  b: [2]

或者转储为带标记的映射:

!defaultdict
default: !list
values: 
  a: [1]
  b: [2]

依我看,带标记的序列更好,因为当default_factory属性为None时(即调用defaultdict()时),可以使用如下简化形式:

!defaultdict
a: "some assigned value"
b: [42]

如果采用第二种(带标记的映射)形式,这种简化形式就无法与恰好只有default和values两个键的defaultdict区分开来。

要实现带标记的序列,您需要为defaultdict编写表示器(representer)和构造器(constructor),还需要为您所提供的工厂函数同样编写这两者。下面以defaultdict(list)为例:

from collections import defaultdict
from pathlib import Path
import ruamel.yaml

class MyRepresenter(ruamel.yaml.SafeRepresenter):
    """SafeRepresenter that dumps defaultdicts together with their factory."""

    def represent_defaultdict(self, data):
        """Represent a defaultdict.

        Without a default_factory it is dumped as a ``!defaultdict`` mapping;
        otherwise as a ``!defaultdict`` sequence ``[factory, {key: value}]``
        so that the factory survives the round trip.
        """
        if data.default_factory is None:
            return self.represent_mapping(u'!defaultdict', data)
        d = [data.default_factory, dict(data)]
        return self.represent_sequence(u'!defaultdict', d)

    def represent_listclass(self, data):
        """Represent the ``list`` class object as an empty ``!list`` scalar.

        This representer is registered for ``type``, i.e. it is consulted for
        *every* class object.  Anything other than ``list`` is rejected
        instead of being silently dumped as ``!list`` (which would reload
        as the wrong default_factory).

        Raises:
            RepresenterError: if *data* is not the ``list`` class.
        """
        if data is not list:
            raise ruamel.yaml.representer.RepresenterError(
                'cannot represent class object: {!r}'.format(data))
        return self.represent_scalar(u'!list', "")

MyRepresenter.add_representer(defaultdict, MyRepresenter.represent_defaultdict)
# type(list) is just ``type``: spell it out, since this registration covers
# all class objects, not only list.
MyRepresenter.add_representer(type, MyRepresenter.represent_listclass)

class MyConstructor(ruamel.yaml.SafeConstructor):
    """SafeConstructor that rebuilds the ``!defaultdict`` and ``!list`` tags."""

    def construct_defaultdict(self, node):
        """Construct a defaultdict from a ``!defaultdict`` node.

        A tagged sequence ``[factory, mapping]`` restores default_factory;
        a tagged mapping yields a defaultdict without a factory.  The empty
        instance is yielded before it is filled in.

        Raises:
            ConstructorError: if the node is neither a two-item sequence
                nor a mapping.
        """
        data = defaultdict()
        yield data
        if isinstance(node, ruamel.yaml.nodes.SequenceNode):
            items = self.construct_sequence(node, deep=True)
            if len(items) != 2:
                raise ruamel.yaml.constructor.ConstructorError(
                    None, None,
                    'expected a [factory, mapping] sequence of 2 items '
                    'for !defaultdict, but found {} items'.format(len(items)),
                    node.start_mark)
            data.default_factory, keyvals = items
            data.update(keyvals)
        elif isinstance(node, ruamel.yaml.nodes.MappingNode):
            value = self.construct_mapping(node, deep=True)
            data.update(value)
        else:
            # Previously a scalar silently became an empty defaultdict.
            raise ruamel.yaml.constructor.ConstructorError(
                None, None,
                'expected a sequence or mapping node for !defaultdict, '
                'but found {}'.format(node.id),
                node.start_mark)

    def construct_listclass(self, node):
        """Return the ``list`` class for a ``!list`` node (value ignored)."""
        return list

MyConstructor.add_constructor(u'!defaultdict', MyConstructor.construct_defaultdict)
MyConstructor.add_constructor(u'!list', MyConstructor.construct_listclass)


yaml = ruamel.yaml.YAML(typ="safe")
yaml.Representer = MyRepresenter
yaml.Constructor = MyConstructor

@yaml.register_class
class Foo:
    # df becomes x's default_factory (None gives a plain defaultdict()).
    def __init__(self, df=None):
        self.x = defaultdict(df)

# d0 has a default_factory (list) and is dumped as a tagged sequence;
# d1 has none and is dumped as a tagged mapping.
d0 = Foo(df=list)
d0.x['a'].append(1)
d0.x['b'].append(2)
d1 = Foo()
d1.x['a'] = "some assigned value"
d1.x['b'] = [42]
d = [d0, d1]

fp =  Path("./test.yaml")
yaml.dump(d, fp)
print(fp.read_text(), end='')
d = yaml.load(fp)

# Appending to a missing key only works because default_factory (list)
# survived the dump/load round trip.
d[0].x['c'].append(3)
print('----')
print(dict(d[0].x))

导致:

- !Foo
  x: !defaultdict
  - !list 
  - a: [1]
    b: [2]
- !Foo
  x: !defaultdict
    a: some assigned value
    b: [42]
----
{'a': [1], 'b': [2], 'c': [3]}

如您所见,重新加载后的d[0].x的default_factory为list。

答案 1 :(得分:1)

这是因为defaultdict是内置类dict的子类,其实例没有__dict__属性,因此YAML编码器无法从中获取属性名。在这种情况下,defaultdict应该被当作dict处理,但问题在于ruamel.yaml.representer.BaseRepresenter类的represent_data方法只根据对象自身的类来判断该对象是否有对应的表示器:

# Excerpt from ruamel.yaml's BaseRepresenter.represent_data: only the exact
# type (data_types[0]) is looked up in yaml_representers, not its bases.
data_types = type(data).__mro__
# ...skipped
if data_types[0] in self.yaml_representers:
    node = self.yaml_representers[data_types[0]](self, data)

正确的做法应该是检查__mro__中是否有任一类型注册了表示器,如果找到就使用它:

# Use the first exact representer found anywhere in the MRO, in a single
# pass (instead of scanning once with any() and again with next()).
representer_func = next(
    (self.yaml_representers[t] for t in data_types if t in self.yaml_representers),
    None,
)
if representer_func is not None:
    node = representer_func(self, data)

因此,我们可以自己对该方法进行猴子修补:

def represent_data(self, data):
    # type: (Any) -> Any
    """Represent *data* as a node, resolving exact representers along the MRO.

    Replacement for ``BaseRepresenter.represent_data``.  ``yaml_representers``
    is searched for every class in ``type(data).__mro__`` instead of only
    ``type(data)``, so e.g. ``defaultdict`` falls back to ``dict``'s
    representer.  Multi-representers and the ``None`` fallbacks are consulted
    afterwards, as in the original method.

    Note: a defaultdict represented via dict's representer is dumped as a
    plain mapping, so it loads back as a regular dict (default_factory lost).
    """
    if self.ignore_aliases(data):
        self.alias_key = None
    else:
        self.alias_key = id(data)
    if self.alias_key is not None:
        if self.alias_key in self.represented_objects:
            # Object seen before: return the node stored for it.
            return self.represented_objects[self.alias_key]
        # Presumably keeps data alive so its id() stays unique while dumping.
        self.object_keeper.append(data)
    data_types = type(data).__mro__
    # PY2 / text_type only exist in older, Python-2-capable ruamel.yaml
    # releases; getattr keeps the patch working where they were removed.
    if getattr(representer, 'PY2', False):
        # if type(data) is types.InstanceType:
        if isinstance(data, representer.types.InstanceType):
            data_types = representer.get_classobj_bases(data.__class__) + list(data_types)
    # Single pass over the MRO for the first exact representer.
    exact = next(
        (self.yaml_representers[t] for t in data_types if t in self.yaml_representers),
        None,
    )
    if exact is not None:
        node = exact(self, data)
    else:
        for data_type in data_types:
            if data_type in self.yaml_multi_representers:
                node = self.yaml_multi_representers[data_type](self, data)
                break
        else:
            if None in self.yaml_multi_representers:
                node = self.yaml_multi_representers[None](self, data)
            elif None in self.yaml_representers:
                node = self.yaml_representers[None](self, data)
            else:
                text_type = getattr(representer, 'text_type', str)
                node = representer.ScalarNode(None, text_type(data))
    return node
representer.BaseRepresenter.represent_data = represent_data

这样,您的代码无需注册defaultdict即可工作(注意:defaultdict会被当作普通映射转储,因此加载回来的是普通dict):

class Foo:
    def __init__(self):
        self.x = defaultdict()

# With the patched represent_data, defaultdict needs no registration: it is
# dumped with dict's representer, so it loads back as a plain dict.
YAML = ruamel.yaml.YAML(typ="safe")
YAML.register_class(Foo)
# YAML.register_class(defaultdict)  -- no longer needed
# Same relative path as the question; an absolute "/temp/..." path fails
# unless that directory happens to exist.
fp = Path("./test.yaml")
YAML.dump(Foo(), fp)
YAML.load(fp)

编辑:一个更优雅的解决方案是直接把SafeRepresenter.represent_dict方法注册为defaultdict的表示器(同样,加载回来的会是普通dict):

from ruamel.yaml import representer

# Reuse dict's representer for defaultdict.  The registration is made on the
# SafeRepresenter class itself (not on one YAML instance), so it presumably
# affects every typ="safe" dumper in the process.  default_factory is not
# dumped, so the data loads back as a plain dict.
representer.SafeRepresenter.add_representer(
    defaultdict, representer.SafeRepresenter.represent_dict
)