我们建立了一个自定义层,该层引用了另一层的形状张量。是否有规范的方法来序列化此属性?自定义图层如下所示:
class myCustomLayer(Layer):
def __init__(self, reference=False, reference_name=False, **kwargs):
self.reference = reference
self.reference_name = reference_name
super(myCustomLayer, self).__init__(**kwargs)
def build(self, input_shape):
super(myCustomLayer, self).build(input_shape)
def call(self, x):
if (self.reference_name != False) and (self.reference == False):
self.reference = tf.get_default_graph().get_tensor_by_name(self.reference_name)
reference_shape = K.shape(self.reference)
# do stuff with x in context of reference_shape
def get_config(self):
if self.reference != False:
reference_name = self.reference.name
config = {'target_shape': self.target_shape,
'reference_name': reference_name}
base_config = super(myCustomLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
您将创建myCustomLayer的实例,例如:
input_tensor = Input(shape=(None, 64, 64, 1))
pooled = AveragePooling2D()(input_tensor)
net = myCustomLayer(reference=pooled)(input_tensor)
...
如您所见,该层会序列化引用的名称,并在运行时重建链接(get_tensor_by_name)。整个结构似乎过于复杂,所以我想知道,在Keras中有没有一种规范的方法可以做到这一点?