我正在使用tf.keras
定义以下子类:
import tensorflow as tf
from tensorflow.keras.layers import Conv3D, MaxPool3D, Flatten, Dense
from tensorflow.keras.layers import Input, BatchNormalization, ReLU
from tensorflow.keras.layers import AvgPool3D
from tensorflow.keras import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Concatenate, Add
from tensorflow.keras.estimator import model_to_estimator
from tensorflow.keras.layers import GlobalAveragePooling3D
class ConvBlock(tf.keras.Model):
def __init__(self, filters, kernel_size, strides=(1, 1, 1), padding='valid', activation=True, **kwargs):
super(ConvBlock, self).__init__(**kwargs)
self.filters = filters
self.kernel_size = kernel_size
self.strides = strides
self.padding = padding
self.activation = activation
self.conv_1 = Conv3D(filters=self.filters,
kernel_size=self.kernel_size,
strides=self.strides,
padding=self.padding,
activation=None)
self.batch_norm_1 = BatchNormalization()
self.relu_1 = ReLU(max_value=6)
def call(self, inputs):
conv = self.conv_1(inputs)
batch_norm = self.batch_norm_1(conv)
if self.activation:
relu = self.relu_1(batch_norm)
return relu
else:
return batch_norm
然后,我使用我上面定义的Model
定义另一个ConvBlock
子类。
class FireModule(tf.keras.Model):
def __init__(self, n_hidden, **kwargs):
super(FireModule, self).__init__(**kwargs)
self.n_hidden = n_hidden
self.squeeze_1 = ConvBlock(filters=self.n_hidden//4, kernel_size=(1, 1, 1), padding='same')
self.expand_1 = ConvBlock(filters=self.n_hidden, kernel_size=(1, 1, 1), padding='same', activation=False)
self.expand_2 = ConvBlock(filters=self.n_hidden, kernel_size=(5, 5, 5), padding='same', activation=False)
self.concat_1 = Concatenate()
self.relu_1 = ReLU(max_value=6)
def call(self, inputs):
f1_sq_1 = self.squeeze_1(inputs)
f1_e_1 = self.expand_1(f1_sq_1)
f1_e_2 = self.expand_2(f1_e_1)
concat_output = self.concat_1([f1_e_1, f1_e_2])
relu = self.relu_1(concat_output)
return relu
快速运行model.summary()
会产生以下结果:
model_subclass = FireModule(n_hidden=32)
model_subclass.build((None, 128, 128, 50, 1))
model_subclass.summary()
Model: "fire_module"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv_block_2 (ConvBlock) multiple 48
_________________________________________________________________
conv_block_3 (ConvBlock) multiple 416
_________________________________________________________________
conv_block_4 (ConvBlock) multiple 128160
_________________________________________________________________
concatenate (Concatenate) multiple 0
_________________________________________________________________
re_lu_5 (ReLU) multiple 0
=================================================================
Total params: 128,624
Trainable params: 128,480
Non-trainable params: 144
请注意“总参数”值。
但是,当我在tf.keras
中定义模型时,参数数量要少得多:
def conv_block(inputs, filters, kernel_size, strides=(1, 1, 1),
padding='valid', activation=True, block_name='conv3d'):
with tf.name_scope(block_name):
conv = Conv3D(filters=filters, kernel_size=kernel_size, strides=strides,
padding=padding, activation=None,
name='{}_conv'.format(block_name))(inputs)
batch_norm = BatchNormalization(
name='{}_batch_norm'.format(block_name))(conv)
if activation:
relu = ReLU(max_value=6, name='{}_relu'.format(block_name))(batch_norm)
res_layer = relu
else:
res_layer = batch_norm
return res_layer
def fire(inputs, n_hidden, block_name):
with tf.name_scope(block_name):
f1_sq_1 = conv_block(inputs, filters=n_hidden//4, kernel_size=(1, 1, 1),
padding='same',
block_name='{}_squeeze_1'.format(block_name))
f1_e_1 = conv_block(f1_sq_1, filters=n_hidden, kernel_size=(1, 1, 1),
activation=False, padding='same',
block_name='{}_expand_1'.format(block_name))
f1_e_2 = conv_block(f1_sq_1, filters=n_hidden, kernel_size=(5, 5, 5),
activation=False, padding='same',
block_name='{}_expand_2'.format(block_name))
concat_output = Concatenate(
name='{}_concatenate'.format(block_name))([f1_e_1, f1_e_2])
relu = ReLU(
max_value=6, name='{}_relu'.format(block_name))(concat_output)
return relu
inputs = Input(shape = (128, 128, 50, 1))
fire_2 = fire(inputs, 32, 'fire_2')
model_keras = Model(inputs=inputs, outputs=fire_2)
model_keras.summary()
Model: "model_2"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_4 (InputLayer) [(None, 128, 128, 50 0
__________________________________________________________________________________________________
fire_2_squeeze_1_conv (Conv3D) (None, 128, 128, 50, 16 input_4[0][0]
__________________________________________________________________________________________________
fire_2_squeeze_1_batch_norm (Ba (None, 128, 128, 50, 32 fire_2_squeeze_1_conv[0][0]
__________________________________________________________________________________________________
fire_2_squeeze_1_relu (ReLU) (None, 128, 128, 50, 0 fire_2_squeeze_1_batch_norm[0][0]
__________________________________________________________________________________________________
fire_2_expand_1_conv (Conv3D) (None, 128, 128, 50, 288 fire_2_squeeze_1_relu[0][0]
__________________________________________________________________________________________________
fire_2_expand_2_conv (Conv3D) (None, 128, 128, 50, 32032 fire_2_squeeze_1_relu[0][0]
__________________________________________________________________________________________________
fire_2_expand_1_batch_norm (Bat (None, 128, 128, 50, 128 fire_2_expand_1_conv[0][0]
__________________________________________________________________________________________________
fire_2_expand_2_batch_norm (Bat (None, 128, 128, 50, 128 fire_2_expand_2_conv[0][0]
__________________________________________________________________________________________________
fire_2_concatenate (Concatenate (None, 128, 128, 50, 0 fire_2_expand_1_batch_norm[0][0]
fire_2_expand_2_batch_norm[0][0]
__________________________________________________________________________________________________
fire_2_relu (ReLU) (None, 128, 128, 50, 0 fire_2_concatenate[0][0]
==================================================================================================
Total params: 32,624
Trainable params: 32,480
Non-trainable params: 144
子类化模型中的参数数目比keras模型大96,000。我的问题是为什么会这样?