Keras:序列化屏蔽层以进行保存/加载

时间:2018-01-22 22:26:15

标签: python keras

所以我在Keras中有一个使用Mask的自定义图层。

要使其与save / load一起使用,我需要正确序列化Mask。所以这个标准代码不起作用:

def get_config(self):
    config =  {'mask': self.mask}
    base_config = super(Mixing, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

其中mask是对Masking Layer的引用。

我不确定如何序列化屏蔽(或一般的Keras图层)。有人可以帮忙吗?

1 个答案:

答案 0 :(得分:1)

您可以实现与内置Wrapper类相同的serializing methods

def get_config(self):
    config = {'layer': {'class_name': self.layer.__class__.__name__,
                        'config': self.layer.get_config()}}
    base_config = super(Wrapper, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

@classmethod
def from_config(cls, config, custom_objects=None):
    from . import deserialize as deserialize_layer
    layer = deserialize_layer(config.pop('layer'),
                              custom_objects=custom_objects)
    return cls(layer, **config)

在序列化期间,在get_config中,内层的班级名称和配置会保存在config['layer']中。

from_config中,使用deserialize_layer使用config['layer']对内层进行反序列化。