所以我在Keras中有一个使用Mask的自定义图层。
要使其与save / load一起使用,我需要正确序列化Mask。所以这个标准代码不起作用:
def get_config(self):
config = {'mask': self.mask}
base_config = super(Mixing, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
其中mask是对Masking Layer的引用。
我不确定如何序列化屏蔽(或一般的Keras图层)。有人可以帮忙吗?
答案 0 :(得分:1)
您可以实现与内置Wrapper
类相同的serializing methods。
def get_config(self):
config = {'layer': {'class_name': self.layer.__class__.__name__,
'config': self.layer.get_config()}}
base_config = super(Wrapper, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@classmethod
def from_config(cls, config, custom_objects=None):
from . import deserialize as deserialize_layer
layer = deserialize_layer(config.pop('layer'),
custom_objects=custom_objects)
return cls(layer, **config)
在序列化期间,在get_config
中,内层的班级名称和配置会保存在config['layer']
中。
在from_config
中,使用deserialize_layer
使用config['layer']
对内层进行反序列化。