在对keras模型进行子类化时,参数数量增加

时间:2019-07-18 11:57:10

标签: python tensorflow keras

我正在使用tf.keras定义以下子类:

import tensorflow as tf
from tensorflow.keras.layers import Conv3D, MaxPool3D, Flatten, Dense
from tensorflow.keras.layers import Input, BatchNormalization, ReLU
from tensorflow.keras.layers import AvgPool3D
from tensorflow.keras import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Concatenate, Add
from tensorflow.keras.estimator import model_to_estimator
from tensorflow.keras.layers import GlobalAveragePooling3D

class ConvBlock(tf.keras.Model):
    def __init__(self, filters, kernel_size, strides=(1, 1, 1), padding='valid', activation=True, **kwargs):
        super(ConvBlock, self).__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.activation = activation
        self.conv_1 = Conv3D(filters=self.filters, 
                             kernel_size=self.kernel_size, 
                             strides=self.strides, 
                             padding=self.padding, 
                             activation=None)    
        self.batch_norm_1 = BatchNormalization()
        self.relu_1 = ReLU(max_value=6)

    def call(self, inputs):
        conv = self.conv_1(inputs)
        batch_norm = self.batch_norm_1(conv)

        if self.activation:
            relu = self.relu_1(batch_norm)
            return relu
        else:
            return batch_norm

然后,我使用我上面定义的Model定义另一个ConvBlock子类。

class FireModule(tf.keras.Model):

    def __init__(self, n_hidden, **kwargs):
        super(FireModule, self).__init__(**kwargs)
        self.n_hidden = n_hidden
        self.squeeze_1 = ConvBlock(filters=self.n_hidden//4, kernel_size=(1, 1, 1), padding='same')
        self.expand_1 = ConvBlock(filters=self.n_hidden, kernel_size=(1, 1, 1), padding='same', activation=False)
        self.expand_2 = ConvBlock(filters=self.n_hidden, kernel_size=(5, 5, 5), padding='same', activation=False) 
        self.concat_1 = Concatenate()
        self.relu_1 = ReLU(max_value=6)

    def call(self, inputs):
        f1_sq_1 = self.squeeze_1(inputs)
        f1_e_1 = self.expand_1(f1_sq_1)
        f1_e_2 = self.expand_2(f1_e_1)
        concat_output = self.concat_1([f1_e_1, f1_e_2])
        relu = self.relu_1(concat_output)
        return relu

快速运行model.summary()会产生以下结果:

model_subclass = FireModule(n_hidden=32)
model_subclass.build((None, 128, 128, 50, 1))
model_subclass.summary()

Model: "fire_module"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv_block_2 (ConvBlock)     multiple                  48        
_________________________________________________________________
conv_block_3 (ConvBlock)     multiple                  416       
_________________________________________________________________
conv_block_4 (ConvBlock)     multiple                  128160    
_________________________________________________________________
concatenate (Concatenate)    multiple                  0         
_________________________________________________________________
re_lu_5 (ReLU)               multiple                  0         
=================================================================
Total params: 128,624
Trainable params: 128,480
Non-trainable params: 144

请注意“总参数”值。

但是,当我在tf.keras中定义模型时,参数数量要少得多:

def conv_block(inputs, filters, kernel_size, strides=(1, 1, 1),
                 padding='valid', activation=True, block_name='conv3d'):

    with tf.name_scope(block_name):
      conv = Conv3D(filters=filters, kernel_size=kernel_size, strides=strides,
                    padding=padding, activation=None,
                    name='{}_conv'.format(block_name))(inputs)
      batch_norm = BatchNormalization(
          name='{}_batch_norm'.format(block_name))(conv)

      if activation:
        relu = ReLU(max_value=6, name='{}_relu'.format(block_name))(batch_norm)
        res_layer = relu
      else:
        res_layer = batch_norm
    return res_layer

def fire(inputs, n_hidden, block_name):

    with tf.name_scope(block_name):
      f1_sq_1 = conv_block(inputs, filters=n_hidden//4, kernel_size=(1, 1, 1),
                           padding='same',
                           block_name='{}_squeeze_1'.format(block_name))
      f1_e_1 = conv_block(f1_sq_1, filters=n_hidden, kernel_size=(1, 1, 1),
                          activation=False, padding='same',
                          block_name='{}_expand_1'.format(block_name))
      f1_e_2 = conv_block(f1_sq_1, filters=n_hidden, kernel_size=(5, 5, 5),
                          activation=False, padding='same',
                          block_name='{}_expand_2'.format(block_name))

      concat_output = Concatenate(
          name='{}_concatenate'.format(block_name))([f1_e_1, f1_e_2])
      relu = ReLU(
          max_value=6, name='{}_relu'.format(block_name))(concat_output)
    return relu

inputs = Input(shape = (128, 128, 50, 1))
fire_2 = fire(inputs, 32, 'fire_2')
model_keras = Model(inputs=inputs, outputs=fire_2)
model_keras.summary()

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_4 (InputLayer)            [(None, 128, 128, 50 0                                            
__________________________________________________________________________________________________
fire_2_squeeze_1_conv (Conv3D)  (None, 128, 128, 50, 16          input_4[0][0]                    
__________________________________________________________________________________________________
fire_2_squeeze_1_batch_norm (Ba (None, 128, 128, 50, 32          fire_2_squeeze_1_conv[0][0]      
__________________________________________________________________________________________________
fire_2_squeeze_1_relu (ReLU)    (None, 128, 128, 50, 0           fire_2_squeeze_1_batch_norm[0][0]
__________________________________________________________________________________________________
fire_2_expand_1_conv (Conv3D)   (None, 128, 128, 50, 288         fire_2_squeeze_1_relu[0][0]      
__________________________________________________________________________________________________
fire_2_expand_2_conv (Conv3D)   (None, 128, 128, 50, 32032       fire_2_squeeze_1_relu[0][0]      
__________________________________________________________________________________________________
fire_2_expand_1_batch_norm (Bat (None, 128, 128, 50, 128         fire_2_expand_1_conv[0][0]       
__________________________________________________________________________________________________
fire_2_expand_2_batch_norm (Bat (None, 128, 128, 50, 128         fire_2_expand_2_conv[0][0]       
__________________________________________________________________________________________________
fire_2_concatenate (Concatenate (None, 128, 128, 50, 0           fire_2_expand_1_batch_norm[0][0] 
                                                                 fire_2_expand_2_batch_norm[0][0] 
__________________________________________________________________________________________________
fire_2_relu (ReLU)              (None, 128, 128, 50, 0           fire_2_concatenate[0][0]         
==================================================================================================
Total params: 32,624
Trainable params: 32,480
Non-trainable params: 144

子类化模型中的参数数目比keras模型大96,000。我的问题是为什么会这样?

0 个答案:

没有答案