在keras中序列化参考张量的规范方法

时间:2018-06-26 21:13:34

标签: keras keras-layer

我们建立了一个自定义层,该层引用了另一层的形状张量。是否有规范的方法来序列化此属性?自定义图层如下所示:

class myCustomLayer(Layer):
def __init__(self, reference=False, reference_name=False,  **kwargs):
    self.reference      = reference
    self.reference_name = reference_name
    super(myCustomLayer, self).__init__(**kwargs)

def build(self, input_shape):
    super(myCustomLayer, self).build(input_shape)

def call(self, x):
    if (self.reference_name != False) and (self.reference == False):
        self.reference = tf.get_default_graph().get_tensor_by_name(self.reference_name)
    reference_shape = K.shape(self.reference)
    # do stuff with x in context of reference_shape

def get_config(self):
    if self.reference != False:
          reference_name = self.reference.name
    config = {'target_shape':   self.target_shape,
              'reference_name': reference_name}
    base_config = super(myCustomLayer, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

您将创建myCustomLayer的实例,例如:

input_tensor = Input(shape=(None, 64, 64, 1))
pooled = AveragePooling2D()(input_tensor)
net = myCustomLayer(reference=pooled)(input_tensor)
...

如您所见,该层会序列化引用的名称,并在运行时重建链接(get_tensor_by_name)。整个结构似乎过于复杂,所以我想知道,在Keras中有没有一种规范的方法可以做到这一点?

0 个答案:

没有答案