使用Conv Layers创建Keras Layer子类

时间:2019-07-17 17:06:55

标签: python tensorflow keras

我想创建类似于以下功能的自定义tf.keras.layers.Layer

def conv_block(inputs, filters, kernel_size, strides=(1, 1, 1),
                 padding='valid', activation=True, block_name='conv3d'):

    with tf.name_scope(block_name):
      conv = Conv3D(filters=filters, kernel_size=kernel_size, strides=strides,
                    padding=padding, activation=None,
                    name='{}_conv'.format(block_name))(inputs)
      batch_norm = BatchNormalization(
          name='{}_batch_norm'.format(block_name))(conv)

      if activation:
        relu = ReLU(max_value=6, name='{}_relu'.format(block_name))(batch_norm)
        res_layer = relu
      else:
        res_layer = batch_norm
    return res_layer

我仔细阅读了herehere的可用文档,然后创建了下面的类:

class ConvBlock(tf.keras.layers.Layer):

    def __init__(self, filters, kernel_size, strides=(1, 1, 1), padding='valid', activation=True, **kwargs):
        super(ConvBlock, self).__init__()
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.activation = activation

        self.conv_1 = Conv3D(filters=self.filters, 
                             kernel_size=self.kernel_size, 
                             strides=self.strides, 
                             padding=self.padding, 
                             activation=None)

        self.batch_norm_1 = BatchNormalization()
        self.relu_1 = ReLU(max_value=6)

    def call(self, inputs):
        conv = self.conv_1(inputs)
        batch_norm = self.batch_norm_1(conv)

        if self.activation:
            relu = self.relu_1(batch_norm)
            return relu
        else:
            return batch_norm

我想在整个模型中多次使用此Layer。我对此有几个疑问:

  1. 文档中提到在add_weights()方法中使用build()。但是在这种情况下是否有必要?
  2. 我是否需要包含build()方法?
  3. 如何获取图层的输出形状?该文档提到使用以下功能:

    def compute_output_shape(self,input_shape):     形状= tf.TensorShape(input_shape).as_list()     形状[-1] = self.output_dim     返回tf.TensorShape(shape)

如何使用此函数计算输出层的形状?

1 个答案:

答案 0 :(得分:0)

也许你可以直接使用一个函数来封装你的重复操作而不是子类化层,只有当你认为你需要使用权重或初始化权重的模式时才使用子类化,因为这是后者的正确方法。

示例:

def simple_conv(x):
   x = Conv2d(x)
   x = Bathcnorm(x)
   return x