I have a MobileNet-based model that performs a regression task:
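# Keras 2.1.6
import sys

from keras.applications.mobilenet import MobileNet
from keras.layers import Dense, Dropout, Flatten
from keras.models import Model
from keras.optimizers import Adadelta

# `config` and `mae_loss` are defined elsewhere in my project.


def MobileNet_v1():
    # MobileNet backbone without the ImageNet classifier head
    mobilenet = MobileNet(input_shape=(config.IMAGE_H, config.IMAGE_W, config.N_CHANNELS),
                          alpha=1.0,
                          depth_multiplier=1,
                          include_top=False,
                          weights='imagenet')
    # Regression head: flatten -> dropout -> linear dense layer (x, y per landmark)
    x = Flatten()(mobilenet.output)
    x = Dropout(0.5)(x)
    x = Dense(config.N_LANDMARKS * 2, activation='linear')(x)
    # -------------------------------------------------------
    model = Model(inputs=mobilenet.input, outputs=x)
    optimizer = Adadelta()
    model.compile(optimizer=optimizer, loss=mae_loss)
    model.summary()
    # Debug only: exit right after printing the summary.
    # Remove this line so that the function actually returns the model.
    sys.exit()
    return model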
Network architecture:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) (None, 128, 128, 3) 0
_________________________________________________________________
conv1_pad (ZeroPadding2D) (None, 130, 130, 3) 0
_________________________________________________________________
conv1 (Conv2D) (None, 64, 64, 32) 864
_________________________________________________________________
conv1_bn (BatchNormalization (None, 64, 64, 32) 128
_________________________________________________________________
conv1_relu (Activation) (None, 64, 64, 32) 0
_________________________________________________________________
conv_pad_1 (ZeroPadding2D) (None, 66, 66, 32) 0
_________________________________________________________________
conv_dw_1 (DepthwiseConv2D) (None, 64, 64, 32) 288
_________________________________________________________________
conv_dw_1_bn (BatchNormaliza (None, 64, 64, 32) 128
_________________________________________________________________
conv_dw_1_relu (Activation) (None, 64, 64, 32) 0
_________________________________________________________________
conv_pw_1 (Conv2D) (None, 64, 64, 64) 2048
_________________________________________________________________
conv_pw_1_bn (BatchNormaliza (None, 64, 64, 64) 256
_________________________________________________________________
conv_pw_1_relu (Activation) (None, 64, 64, 64) 0
_________________________________________________________________
conv_pad_2 (ZeroPadding2D) (None, 66, 66, 64) 0
_________________________________________________________________
conv_dw_2 (DepthwiseConv2D) (None, 32, 32, 64) 576
_________________________________________________________________
conv_dw_2_bn (BatchNormaliza (None, 32, 32, 64) 256
_________________________________________________________________
conv_dw_2_relu (Activation) (None, 32, 32, 64) 0
_________________________________________________________________
conv_pw_2 (Conv2D) (None, 32, 32, 128) 8192
_________________________________________________________________
conv_pw_2_bn (BatchNormaliza (None, 32, 32, 128) 512
_________________________________________________________________
conv_pw_2_relu (Activation) (None, 32, 32, 128) 0
_________________________________________________________________
conv_pad_3 (ZeroPadding2D) (None, 34, 34, 128) 0
_________________________________________________________________
conv_dw_3 (DepthwiseConv2D) (None, 32, 32, 128) 1152
_________________________________________________________________
conv_dw_3_bn (BatchNormaliza (None, 32, 32, 128) 512
_________________________________________________________________
conv_dw_3_relu (Activation) (None, 32, 32, 128) 0
_________________________________________________________________
conv_pw_3 (Conv2D) (None, 32, 32, 128) 16384
_________________________________________________________________
conv_pw_3_bn (BatchNormaliza (None, 32, 32, 128) 512
_________________________________________________________________
conv_pw_3_relu (Activation) (None, 32, 32, 128) 0
_________________________________________________________________
conv_pad_4 (ZeroPadding2D) (None, 34, 34, 128) 0
_________________________________________________________________
conv_dw_4 (DepthwiseConv2D) (None, 16, 16, 128) 1152
_________________________________________________________________
conv_dw_4_bn (BatchNormaliza (None, 16, 16, 128) 512
_________________________________________________________________
conv_dw_4_relu (Activation) (None, 16, 16, 128) 0
_________________________________________________________________
conv_pw_4 (Conv2D) (None, 16, 16, 256) 32768
_________________________________________________________________
conv_pw_4_bn (BatchNormaliza (None, 16, 16, 256) 1024
_________________________________________________________________
conv_pw_4_relu (Activation) (None, 16, 16, 256) 0
_________________________________________________________________
conv_pad_5 (ZeroPadding2D) (None, 18, 18, 256) 0
_________________________________________________________________
conv_dw_5 (DepthwiseConv2D) (None, 16, 16, 256) 2304
_________________________________________________________________
conv_dw_5_bn (BatchNormaliza (None, 16, 16, 256) 1024
_________________________________________________________________
conv_dw_5_relu (Activation) (None, 16, 16, 256) 0
_________________________________________________________________
conv_pw_5 (Conv2D) (None, 16, 16, 256) 65536
_________________________________________________________________
conv_pw_5_bn (BatchNormaliza (None, 16, 16, 256) 1024
_________________________________________________________________
conv_pw_5_relu (Activation) (None, 16, 16, 256) 0
_________________________________________________________________
conv_pad_6 (ZeroPadding2D) (None, 18, 18, 256) 0
_________________________________________________________________
conv_dw_6 (DepthwiseConv2D) (None, 8, 8, 256) 2304
_________________________________________________________________
conv_dw_6_bn (BatchNormaliza (None, 8, 8, 256) 1024
_________________________________________________________________
conv_dw_6_relu (Activation) (None, 8, 8, 256) 0
_________________________________________________________________
conv_pw_6 (Conv2D) (None, 8, 8, 512) 131072
_________________________________________________________________
conv_pw_6_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_pw_6_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pad_7 (ZeroPadding2D) (None, 10, 10, 512) 0
_________________________________________________________________
conv_dw_7 (DepthwiseConv2D) (None, 8, 8, 512) 4608
_________________________________________________________________
conv_dw_7_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_dw_7_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pw_7 (Conv2D) (None, 8, 8, 512) 262144
_________________________________________________________________
conv_pw_7_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_pw_7_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pad_8 (ZeroPadding2D) (None, 10, 10, 512) 0
_________________________________________________________________
conv_dw_8 (DepthwiseConv2D) (None, 8, 8, 512) 4608
_________________________________________________________________
conv_dw_8_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_dw_8_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pw_8 (Conv2D) (None, 8, 8, 512) 262144
_________________________________________________________________
conv_pw_8_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_pw_8_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pad_9 (ZeroPadding2D) (None, 10, 10, 512) 0
_________________________________________________________________
conv_dw_9 (DepthwiseConv2D) (None, 8, 8, 512) 4608
_________________________________________________________________
conv_dw_9_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_dw_9_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pw_9 (Conv2D) (None, 8, 8, 512) 262144
_________________________________________________________________
conv_pw_9_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_pw_9_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pad_10 (ZeroPadding2D) (None, 10, 10, 512) 0
_________________________________________________________________
conv_dw_10 (DepthwiseConv2D) (None, 8, 8, 512) 4608
_________________________________________________________________
conv_dw_10_bn (BatchNormaliz (None, 8, 8, 512) 2048
_________________________________________________________________
conv_dw_10_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pw_10 (Conv2D) (None, 8, 8, 512) 262144
_________________________________________________________________
conv_pw_10_bn (BatchNormaliz (None, 8, 8, 512) 2048
_________________________________________________________________
conv_pw_10_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pad_11 (ZeroPadding2D) (None, 10, 10, 512) 0
_________________________________________________________________
conv_dw_11 (DepthwiseConv2D) (None, 8, 8, 512) 4608
_________________________________________________________________
conv_dw_11_bn (BatchNormaliz (None, 8, 8, 512) 2048
_________________________________________________________________
conv_dw_11_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pw_11 (Conv2D) (None, 8, 8, 512) 262144
_________________________________________________________________
conv_pw_11_bn (BatchNormaliz (None, 8, 8, 512) 2048
_________________________________________________________________
conv_pw_11_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pad_12 (ZeroPadding2D) (None, 10, 10, 512) 0
_________________________________________________________________
conv_dw_12 (DepthwiseConv2D) (None, 4, 4, 512) 4608
_________________________________________________________________
conv_dw_12_bn (BatchNormaliz (None, 4, 4, 512) 2048
_________________________________________________________________
conv_dw_12_relu (Activation) (None, 4, 4, 512) 0
_________________________________________________________________
conv_pw_12 (Conv2D) (None, 4, 4, 1024) 524288
_________________________________________________________________
conv_pw_12_bn (BatchNormaliz (None, 4, 4, 1024) 4096
_________________________________________________________________
conv_pw_12_relu (Activation) (None, 4, 4, 1024) 0
_________________________________________________________________
conv_pad_13 (ZeroPadding2D) (None, 6, 6, 1024) 0
_________________________________________________________________
conv_dw_13 (DepthwiseConv2D) (None, 4, 4, 1024) 9216
_________________________________________________________________
conv_dw_13_bn (BatchNormaliz (None, 4, 4, 1024) 4096
_________________________________________________________________
conv_dw_13_relu (Activation) (None, 4, 4, 1024) 0
_________________________________________________________________
conv_pw_13 (Conv2D) (None, 4, 4, 1024) 1048576
_________________________________________________________________
conv_pw_13_bn (BatchNormaliz (None, 4, 4, 1024) 4096
_________________________________________________________________
conv_pw_13_relu (Activation) (None, 4, 4, 1024) 0
_________________________________________________________________
flatten_1 (Flatten) (None, 16384) 0
_________________________________________________________________
dropout_1 (Dropout) (None, 16384) 0
_________________________________________________________________
dense_1 (Dense) (None, 156) 2556060
=================================================================
Total params: 5,784,924
Trainable params: 5,763,036
Non-trainable params: 21,888
_________________________________________________________________
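As we can see, about half of the network's parameters sit in the last dense layer. So my question is: how can I reduce the model size once the network has already been trained? I have already tested global average pooling in place of the dense layer, and it performs poorly on my regression task, so that is not an option. I am therefore looking for something like reducing the size of the dense layer or sparsifying it.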
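As a quick sanity check of the "about half" figure, the dense layer's share can be recomputed from the numbers in the summary above. The variable names here are only illustrative:

```python
# Numbers copied from the model summary above.
flatten_units = 4 * 4 * 1024   # flatten_1 output size (16384)
n_outputs = 156                # dense_1 output size (config.N_LANDMARKS * 2)
total_params = 5784924         # "Total params" line of the summary

# A Dense layer stores one weight per (input, output) pair plus one bias per output.
dense_params = flatten_units * n_outputs + n_outputs
dense_share = dense_params / total_params

print(dense_params)                       # 2556060, matches dense_1 in the summary
print("{:.1%}".format(dense_share))       # a bit under half of all parameters
```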
Update:
Example of the network with global average pooling:
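# Keras 2.1.6
import sys

from keras.applications.mobilenet import MobileNet
from keras.layers import Conv2D, GlobalAveragePooling2D
from keras.models import Model
from keras.optimizers import Adadelta

# `config` and `mae_loss` are defined elsewhere in my project.


def MobileNet_v2():
    # MobileNet with GAP layer on top
    mobilenet = MobileNet(input_shape=(config.IMAGE_H, config.IMAGE_W, config.N_CHANNELS),
                          alpha=1.0,
                          depth_multiplier=1,
                          include_top=False,
                          weights='imagenet')
    # 1x1 convolution produces one 4x4 map per output; GAP averages each map to a scalar
    x = Conv2D(filters=config.N_LANDMARKS * 2, kernel_size=(1, 1), activation='linear')(mobilenet.output)
    x = GlobalAveragePooling2D()(x)
    # -------------------------------------------------------
    model = Model(inputs=mobilenet.input, outputs=x)
    optimizer = Adadelta()
    model.compile(optimizer=optimizer, loss=mae_loss)
    model.summary()
    # Debug only: exit right after printing the summary.
    # Remove this line so that the function actually returns the model.
    sys.exit()
    return model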
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) (None, 128, 128, 3) 0
_________________________________________________________________
conv1_pad (ZeroPadding2D) (None, 130, 130, 3) 0
_________________________________________________________________
conv1 (Conv2D) (None, 64, 64, 32) 864
_________________________________________________________________
conv1_bn (BatchNormalization (None, 64, 64, 32) 128
_________________________________________________________________
conv1_relu (Activation) (None, 64, 64, 32) 0
_________________________________________________________________
conv_pad_1 (ZeroPadding2D) (None, 66, 66, 32) 0
_________________________________________________________________
conv_dw_1 (DepthwiseConv2D) (None, 64, 64, 32) 288
_________________________________________________________________
conv_dw_1_bn (BatchNormaliza (None, 64, 64, 32) 128
_________________________________________________________________
conv_dw_1_relu (Activation) (None, 64, 64, 32) 0
_________________________________________________________________
conv_pw_1 (Conv2D) (None, 64, 64, 64) 2048
_________________________________________________________________
conv_pw_1_bn (BatchNormaliza (None, 64, 64, 64) 256
_________________________________________________________________
conv_pw_1_relu (Activation) (None, 64, 64, 64) 0
_________________________________________________________________
conv_pad_2 (ZeroPadding2D) (None, 66, 66, 64) 0
_________________________________________________________________
conv_dw_2 (DepthwiseConv2D) (None, 32, 32, 64) 576
_________________________________________________________________
conv_dw_2_bn (BatchNormaliza (None, 32, 32, 64) 256
_________________________________________________________________
conv_dw_2_relu (Activation) (None, 32, 32, 64) 0
_________________________________________________________________
conv_pw_2 (Conv2D) (None, 32, 32, 128) 8192
_________________________________________________________________
conv_pw_2_bn (BatchNormaliza (None, 32, 32, 128) 512
_________________________________________________________________
conv_pw_2_relu (Activation) (None, 32, 32, 128) 0
_________________________________________________________________
conv_pad_3 (ZeroPadding2D) (None, 34, 34, 128) 0
_________________________________________________________________
conv_dw_3 (DepthwiseConv2D) (None, 32, 32, 128) 1152
_________________________________________________________________
conv_dw_3_bn (BatchNormaliza (None, 32, 32, 128) 512
_________________________________________________________________
conv_dw_3_relu (Activation) (None, 32, 32, 128) 0
_________________________________________________________________
conv_pw_3 (Conv2D) (None, 32, 32, 128) 16384
_________________________________________________________________
conv_pw_3_bn (BatchNormaliza (None, 32, 32, 128) 512
_________________________________________________________________
conv_pw_3_relu (Activation) (None, 32, 32, 128) 0
_________________________________________________________________
conv_pad_4 (ZeroPadding2D) (None, 34, 34, 128) 0
_________________________________________________________________
conv_dw_4 (DepthwiseConv2D) (None, 16, 16, 128) 1152
_________________________________________________________________
conv_dw_4_bn (BatchNormaliza (None, 16, 16, 128) 512
_________________________________________________________________
conv_dw_4_relu (Activation) (None, 16, 16, 128) 0
_________________________________________________________________
conv_pw_4 (Conv2D) (None, 16, 16, 256) 32768
_________________________________________________________________
conv_pw_4_bn (BatchNormaliza (None, 16, 16, 256) 1024
_________________________________________________________________
conv_pw_4_relu (Activation) (None, 16, 16, 256) 0
_________________________________________________________________
conv_pad_5 (ZeroPadding2D) (None, 18, 18, 256) 0
_________________________________________________________________
conv_dw_5 (DepthwiseConv2D) (None, 16, 16, 256) 2304
_________________________________________________________________
conv_dw_5_bn (BatchNormaliza (None, 16, 16, 256) 1024
_________________________________________________________________
conv_dw_5_relu (Activation) (None, 16, 16, 256) 0
_________________________________________________________________
conv_pw_5 (Conv2D) (None, 16, 16, 256) 65536
_________________________________________________________________
conv_pw_5_bn (BatchNormaliza (None, 16, 16, 256) 1024
_________________________________________________________________
conv_pw_5_relu (Activation) (None, 16, 16, 256) 0
_________________________________________________________________
conv_pad_6 (ZeroPadding2D) (None, 18, 18, 256) 0
_________________________________________________________________
conv_dw_6 (DepthwiseConv2D) (None, 8, 8, 256) 2304
_________________________________________________________________
conv_dw_6_bn (BatchNormaliza (None, 8, 8, 256) 1024
_________________________________________________________________
conv_dw_6_relu (Activation) (None, 8, 8, 256) 0
_________________________________________________________________
conv_pw_6 (Conv2D) (None, 8, 8, 512) 131072
_________________________________________________________________
conv_pw_6_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_pw_6_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pad_7 (ZeroPadding2D) (None, 10, 10, 512) 0
_________________________________________________________________
conv_dw_7 (DepthwiseConv2D) (None, 8, 8, 512) 4608
_________________________________________________________________
conv_dw_7_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_dw_7_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pw_7 (Conv2D) (None, 8, 8, 512) 262144
_________________________________________________________________
conv_pw_7_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_pw_7_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pad_8 (ZeroPadding2D) (None, 10, 10, 512) 0
_________________________________________________________________
conv_dw_8 (DepthwiseConv2D) (None, 8, 8, 512) 4608
_________________________________________________________________
conv_dw_8_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_dw_8_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pw_8 (Conv2D) (None, 8, 8, 512) 262144
_________________________________________________________________
conv_pw_8_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_pw_8_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pad_9 (ZeroPadding2D) (None, 10, 10, 512) 0
_________________________________________________________________
conv_dw_9 (DepthwiseConv2D) (None, 8, 8, 512) 4608
_________________________________________________________________
conv_dw_9_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_dw_9_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pw_9 (Conv2D) (None, 8, 8, 512) 262144
_________________________________________________________________
conv_pw_9_bn (BatchNormaliza (None, 8, 8, 512) 2048
_________________________________________________________________
conv_pw_9_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pad_10 (ZeroPadding2D) (None, 10, 10, 512) 0
_________________________________________________________________
conv_dw_10 (DepthwiseConv2D) (None, 8, 8, 512) 4608
_________________________________________________________________
conv_dw_10_bn (BatchNormaliz (None, 8, 8, 512) 2048
_________________________________________________________________
conv_dw_10_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pw_10 (Conv2D) (None, 8, 8, 512) 262144
_________________________________________________________________
conv_pw_10_bn (BatchNormaliz (None, 8, 8, 512) 2048
_________________________________________________________________
conv_pw_10_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pad_11 (ZeroPadding2D) (None, 10, 10, 512) 0
_________________________________________________________________
conv_dw_11 (DepthwiseConv2D) (None, 8, 8, 512) 4608
_________________________________________________________________
conv_dw_11_bn (BatchNormaliz (None, 8, 8, 512) 2048
_________________________________________________________________
conv_dw_11_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pw_11 (Conv2D) (None, 8, 8, 512) 262144
_________________________________________________________________
conv_pw_11_bn (BatchNormaliz (None, 8, 8, 512) 2048
_________________________________________________________________
conv_pw_11_relu (Activation) (None, 8, 8, 512) 0
_________________________________________________________________
conv_pad_12 (ZeroPadding2D) (None, 10, 10, 512) 0
_________________________________________________________________
conv_dw_12 (DepthwiseConv2D) (None, 4, 4, 512) 4608
_________________________________________________________________
conv_dw_12_bn (BatchNormaliz (None, 4, 4, 512) 2048
_________________________________________________________________
conv_dw_12_relu (Activation) (None, 4, 4, 512) 0
_________________________________________________________________
conv_pw_12 (Conv2D) (None, 4, 4, 1024) 524288
_________________________________________________________________
conv_pw_12_bn (BatchNormaliz (None, 4, 4, 1024) 4096
_________________________________________________________________
conv_pw_12_relu (Activation) (None, 4, 4, 1024) 0
_________________________________________________________________
conv_pad_13 (ZeroPadding2D) (None, 6, 6, 1024) 0
_________________________________________________________________
conv_dw_13 (DepthwiseConv2D) (None, 4, 4, 1024) 9216
_________________________________________________________________
conv_dw_13_bn (BatchNormaliz (None, 4, 4, 1024) 4096
_________________________________________________________________
conv_dw_13_relu (Activation) (None, 4, 4, 1024) 0
_________________________________________________________________
conv_pw_13 (Conv2D) (None, 4, 4, 1024) 1048576
_________________________________________________________________
conv_pw_13_bn (BatchNormaliz (None, 4, 4, 1024) 4096
_________________________________________________________________
conv_pw_13_relu (Activation) (None, 4, 4, 1024) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 4, 4, 156) 159900
_________________________________________________________________
global_average_pooling2d_1 ( (None, 156) 0
=================================================================
Total params: 3,388,764
Trainable params: 3,366,876
Non-trainable params: 21,888
Answer 0 (score: 0)
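Regarding reducing the size of the dense layer:
Compress the weight matrix W of the inner product (fully connected) layer using truncated SVD: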
https://github.com/rbgirshick/py-faster-rcnn/blob/master/tools/compress_net.py
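
The linked script targets Caffe, so the following is only a minimal NumPy sketch of the same idea, not a port of that script. It assumes the dense kernel is stored as a matrix of shape (n_in, n_out), as Keras does; the function name `truncated_svd_factors`, the stand-in sizes and the chosen rank are illustrative assumptions. The trained layer `Dense(n_out)` would be replaced by a bias-free linear `Dense(rank)` followed by `Dense(n_out)`, initialised with the two factors.

```python
import numpy as np


def truncated_svd_factors(W, rank):
    """Split a dense kernel W (n_in, n_out) into two thinner kernels.

    Returns W1 (n_in, rank) and W2 (rank, n_out) such that W1 @ W2 is the
    best rank-`rank` approximation of W in the least-squares sense.
    """
    U, s, Vt = np.linalg.svd(W, full_matrices=False)
    W1 = U[:, :rank] * s[:rank]   # fold the singular values into the first factor
    W2 = Vt[:rank, :]
    return W1, W2


def dense_params(n_in, n_out, rank=None):
    """Parameter count of one Dense layer, or of its two-layer factorisation."""
    if rank is None:
        return n_in * n_out + n_out
    return n_in * rank + rank * n_out + n_out  # only the second layer keeps a bias


# Synthetic stand-in for trained weights (small sizes so the demo runs quickly):
# an approximately low-rank matrix plus a little noise.
rng = np.random.RandomState(0)
n_in, n_out, rank = 512, 156, 32
W = rng.randn(n_in, 16) @ rng.randn(16, n_out) + 0.01 * rng.randn(n_in, n_out)
b = rng.randn(n_out)

W1, W2 = truncated_svd_factors(W, rank)

# Compare the original layer output with the factorised one on random inputs.
x = rng.randn(4, n_in)
y_full = x @ W + b
y_low = (x @ W1) @ W2 + b
rel_err = np.linalg.norm(y_full - y_low) / np.linalg.norm(y_full)
print("relative output error:", rel_err)

# Parameter counts for the layer sizes in the question (16384 -> 156).
print(dense_params(16384, 156))           # original dense layer
print(dense_params(16384, 156, rank=32))  # factorised with rank 32
```

How small the rank can go depends on the singular values of the real trained kernel, so the error should be measured on validation data, and a short fine-tuning run after the replacement may be needed. For a 16384 x 156 kernel the factorisation only saves parameters when the rank is below roughly 154.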