如何在Keras中实现CRelu?

时间:2019-02-27 17:40:53

标签: python tensorflow keras deep-learning

我正在尝试在Keras中实现CRelu层

似乎可行的一种选择是使用Lambda层:

def _crelu(x):
    x = tf.nn.crelu(x, axis=-1)
    return x

def _conv_bn_crelu(x, n_filters, kernel_size):
    x = Conv2D(filters=n_filters, kernel_size=kernel_size, strides=(1, 1), padding='same')(x)
    x = BatchNormalization(axis=-1)(x)
    x = Lambda(_crelu)(x)
    return x

但是我不知道Lamda层会在训练或推理过程中引入一些开销吗?

我的第二个尝试是创建围绕tf.nn.crelu

的keras层。
class CRelu(Layer):
    def __init__(self, **kwargs):
        super(CRelu, self).__init__(**kwargs)

    def build(self, input_shape):
        super(CRelu, self).build(input_shape)

    def call(self, x):
        x = tf.nn.crelu(x, axis=-1)
        return x

    def compute_output_shape(self, input_shape):
        output_shape = list(input_shape)
        output_shape[-1] = output_shape[-1] * 2
        output_shape = tuple(output_shape)
        return output_shape

def _conv_bn_crelu(x, n_filters, kernel_size):
    x = Conv2D(filters=n_filters, kernel_size=kernel_size, strides=(1, 1), padding='same')(x)
    x = BatchNormalization(axis=-1)(x)
    x = CRelu()(x)
    return x

哪个版本会更有效?

如果可能,也期待纯Keras实施。

1 个答案:

答案 0 :(得分:0)

我认为这两种实现在速度方面没有显着差异。

Lambda实现实际上是最简单的,但通常编写一个自定义图层会更好,特别是对于模型保存和加载( get_config 方法)而言。

但是在这种情况下,这无关紧要,因为CReLU是微不足道的,并且不需要保存和恢复参数。您实际上可以按照以下代码存储axis参数。这样,将在加载模型时自动检索它。

class CRelu(Layer):
    def __init__(self, axis=-1, **kwargs):
        self.axis = axis 
        super(CRelu, self).__init__(**kwargs)

    def build(self, input_shape):
        super(CRelu, self).build(input_shape)

    def call(self, x):
        x = tf.nn.crelu(x, axis=self.axis)
        return x

    def compute_output_shape(self, input_shape):
        output_shape = list(input_shape)
        output_shape[-1] = output_shape[-1] * 2
        output_shape = tuple(output_shape)
        return output_shape

    def get_config(self, input_shape):
        config = {'axis': self.axis, }
        base_config = super(CReLU, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))