如何重用“复合”Keras层?

时间:2017-05-19 19:41:47

标签: python keras

所以,我有这个小助手功能:

def ResConv(input, size):
    return BatchNormalization()(Add()([
        GLU()(Conv1D(size*2, 5, padding='causal',)(input)),
        input
    ]))

它创建了一起使用的特定层序列;很清楚。

然而,现在我意识到我需要在不同的输入上重用相同的层;也就是说,我需要这样的东西

my_res_conv = ResConv(100)
layer_a = my_res_conv(input_a)
layer_b = my_res_conv(input_b)
concat = concatenate([layer_a, layer_b])

并且layer_alayer_b分享权重。

我该怎么做?我是否必须编写自定义图层?我之前从未这样做过,而且我不确定如何处理这种情况。

1 个答案:

答案 0 :(得分:1)

我最终真正制作了这样的自定义类:

class ResConv():
    def __init__(self, size):
        self.conv = Conv1D(size*2, 5, padding='causal')
        self.batchnorm = BatchNormalization()
        super(ResConv, self).__init__()

    def __call__(self, inputs):
        return self.batchnorm(Add()([
            GLU()(self.conv(inputs)),
            inputs
        ]))

基本上,您在__init__中初始化图层,并在__call__中写出整个计算序列;这样,每次调用它时,您的类都会将相同的图层重新应用于新输入。