How to reduce the size of an existing network?

Asked: 2018-11-18 13:06:33

Tags: python tensorflow machine-learning keras deep-learning

I have a MobileNet-based model that performs a regression task:

# Imports needed by this snippet (config and mae_loss come from my own code):
from keras.applications.mobilenet import MobileNet
from keras.layers import Flatten, Dropout, Dense
from keras.models import Model
from keras.optimizers import Adadelta

def MobileNet_v1():
    # Keras 2.1.6
    mobilenet = MobileNet(input_shape=(config.IMAGE_H, config.IMAGE_W, config.N_CHANNELS),
                          alpha=1.0,
                          depth_multiplier=1,
                          include_top=False,
                          weights='imagenet'
                          )

    x = Flatten()(mobilenet.output)
    x = Dropout(0.5)(x)
    x = Dense(config.N_LANDMARKS * 2, activation='linear')(x)

    # -------------------------------------------------------

    model = Model(inputs=mobilenet.input, outputs=x)
    optimizer = Adadelta()
    model.compile(optimizer=optimizer, loss=mae_loss)

    model.summary()
    import sys
    sys.exit()  # debug only: stop after printing the architecture

    return model

Network architecture:

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, 128, 128, 3)       0
_________________________________________________________________
conv1_pad (ZeroPadding2D)    (None, 130, 130, 3)       0
_________________________________________________________________
conv1 (Conv2D)               (None, 64, 64, 32)        864
_________________________________________________________________
conv1_bn (BatchNormalization (None, 64, 64, 32)        128
_________________________________________________________________
conv1_relu (Activation)      (None, 64, 64, 32)        0
_________________________________________________________________
conv_pad_1 (ZeroPadding2D)   (None, 66, 66, 32)        0
_________________________________________________________________
conv_dw_1 (DepthwiseConv2D)  (None, 64, 64, 32)        288
_________________________________________________________________
conv_dw_1_bn (BatchNormaliza (None, 64, 64, 32)        128
_________________________________________________________________
conv_dw_1_relu (Activation)  (None, 64, 64, 32)        0
_________________________________________________________________
conv_pw_1 (Conv2D)           (None, 64, 64, 64)        2048
_________________________________________________________________
conv_pw_1_bn (BatchNormaliza (None, 64, 64, 64)        256
_________________________________________________________________
conv_pw_1_relu (Activation)  (None, 64, 64, 64)        0
_________________________________________________________________
conv_pad_2 (ZeroPadding2D)   (None, 66, 66, 64)        0
_________________________________________________________________
conv_dw_2 (DepthwiseConv2D)  (None, 32, 32, 64)        576
_________________________________________________________________
conv_dw_2_bn (BatchNormaliza (None, 32, 32, 64)        256
_________________________________________________________________
conv_dw_2_relu (Activation)  (None, 32, 32, 64)        0
_________________________________________________________________
conv_pw_2 (Conv2D)           (None, 32, 32, 128)       8192
_________________________________________________________________
conv_pw_2_bn (BatchNormaliza (None, 32, 32, 128)       512
_________________________________________________________________
conv_pw_2_relu (Activation)  (None, 32, 32, 128)       0
_________________________________________________________________
conv_pad_3 (ZeroPadding2D)   (None, 34, 34, 128)       0
_________________________________________________________________
conv_dw_3 (DepthwiseConv2D)  (None, 32, 32, 128)       1152
_________________________________________________________________
conv_dw_3_bn (BatchNormaliza (None, 32, 32, 128)       512
_________________________________________________________________
conv_dw_3_relu (Activation)  (None, 32, 32, 128)       0
_________________________________________________________________
conv_pw_3 (Conv2D)           (None, 32, 32, 128)       16384
_________________________________________________________________
conv_pw_3_bn (BatchNormaliza (None, 32, 32, 128)       512
_________________________________________________________________
conv_pw_3_relu (Activation)  (None, 32, 32, 128)       0
_________________________________________________________________
conv_pad_4 (ZeroPadding2D)   (None, 34, 34, 128)       0
_________________________________________________________________
conv_dw_4 (DepthwiseConv2D)  (None, 16, 16, 128)       1152
_________________________________________________________________
conv_dw_4_bn (BatchNormaliza (None, 16, 16, 128)       512
_________________________________________________________________
conv_dw_4_relu (Activation)  (None, 16, 16, 128)       0
_________________________________________________________________
conv_pw_4 (Conv2D)           (None, 16, 16, 256)       32768
_________________________________________________________________
conv_pw_4_bn (BatchNormaliza (None, 16, 16, 256)       1024
_________________________________________________________________
conv_pw_4_relu (Activation)  (None, 16, 16, 256)       0
_________________________________________________________________
conv_pad_5 (ZeroPadding2D)   (None, 18, 18, 256)       0
_________________________________________________________________
conv_dw_5 (DepthwiseConv2D)  (None, 16, 16, 256)       2304
_________________________________________________________________
conv_dw_5_bn (BatchNormaliza (None, 16, 16, 256)       1024
_________________________________________________________________
conv_dw_5_relu (Activation)  (None, 16, 16, 256)       0
_________________________________________________________________
conv_pw_5 (Conv2D)           (None, 16, 16, 256)       65536
_________________________________________________________________
conv_pw_5_bn (BatchNormaliza (None, 16, 16, 256)       1024
_________________________________________________________________
conv_pw_5_relu (Activation)  (None, 16, 16, 256)       0
_________________________________________________________________
conv_pad_6 (ZeroPadding2D)   (None, 18, 18, 256)       0
_________________________________________________________________
conv_dw_6 (DepthwiseConv2D)  (None, 8, 8, 256)         2304
_________________________________________________________________
conv_dw_6_bn (BatchNormaliza (None, 8, 8, 256)         1024
_________________________________________________________________
conv_dw_6_relu (Activation)  (None, 8, 8, 256)         0
_________________________________________________________________
conv_pw_6 (Conv2D)           (None, 8, 8, 512)         131072
_________________________________________________________________
conv_pw_6_bn (BatchNormaliza (None, 8, 8, 512)         2048
_________________________________________________________________
conv_pw_6_relu (Activation)  (None, 8, 8, 512)         0
_________________________________________________________________
conv_pad_7 (ZeroPadding2D)   (None, 10, 10, 512)       0
_________________________________________________________________
conv_dw_7 (DepthwiseConv2D)  (None, 8, 8, 512)         4608
_________________________________________________________________
conv_dw_7_bn (BatchNormaliza (None, 8, 8, 512)         2048
_________________________________________________________________
conv_dw_7_relu (Activation)  (None, 8, 8, 512)         0
_________________________________________________________________
conv_pw_7 (Conv2D)           (None, 8, 8, 512)         262144
_________________________________________________________________
conv_pw_7_bn (BatchNormaliza (None, 8, 8, 512)         2048
_________________________________________________________________
conv_pw_7_relu (Activation)  (None, 8, 8, 512)         0
_________________________________________________________________
conv_pad_8 (ZeroPadding2D)   (None, 10, 10, 512)       0
_________________________________________________________________
conv_dw_8 (DepthwiseConv2D)  (None, 8, 8, 512)         4608
_________________________________________________________________
conv_dw_8_bn (BatchNormaliza (None, 8, 8, 512)         2048
_________________________________________________________________
conv_dw_8_relu (Activation)  (None, 8, 8, 512)         0
_________________________________________________________________
conv_pw_8 (Conv2D)           (None, 8, 8, 512)         262144
_________________________________________________________________
conv_pw_8_bn (BatchNormaliza (None, 8, 8, 512)         2048
_________________________________________________________________
conv_pw_8_relu (Activation)  (None, 8, 8, 512)         0
_________________________________________________________________
conv_pad_9 (ZeroPadding2D)   (None, 10, 10, 512)       0
_________________________________________________________________
conv_dw_9 (DepthwiseConv2D)  (None, 8, 8, 512)         4608
_________________________________________________________________
conv_dw_9_bn (BatchNormaliza (None, 8, 8, 512)         2048
_________________________________________________________________
conv_dw_9_relu (Activation)  (None, 8, 8, 512)         0
_________________________________________________________________
conv_pw_9 (Conv2D)           (None, 8, 8, 512)         262144
_________________________________________________________________
conv_pw_9_bn (BatchNormaliza (None, 8, 8, 512)         2048
_________________________________________________________________
conv_pw_9_relu (Activation)  (None, 8, 8, 512)         0
_________________________________________________________________
conv_pad_10 (ZeroPadding2D)  (None, 10, 10, 512)       0
_________________________________________________________________
conv_dw_10 (DepthwiseConv2D) (None, 8, 8, 512)         4608
_________________________________________________________________
conv_dw_10_bn (BatchNormaliz (None, 8, 8, 512)         2048
_________________________________________________________________
conv_dw_10_relu (Activation) (None, 8, 8, 512)         0
_________________________________________________________________
conv_pw_10 (Conv2D)          (None, 8, 8, 512)         262144
_________________________________________________________________
conv_pw_10_bn (BatchNormaliz (None, 8, 8, 512)         2048
_________________________________________________________________
conv_pw_10_relu (Activation) (None, 8, 8, 512)         0
_________________________________________________________________
conv_pad_11 (ZeroPadding2D)  (None, 10, 10, 512)       0
_________________________________________________________________
conv_dw_11 (DepthwiseConv2D) (None, 8, 8, 512)         4608
_________________________________________________________________
conv_dw_11_bn (BatchNormaliz (None, 8, 8, 512)         2048
_________________________________________________________________
conv_dw_11_relu (Activation) (None, 8, 8, 512)         0
_________________________________________________________________
conv_pw_11 (Conv2D)          (None, 8, 8, 512)         262144
_________________________________________________________________
conv_pw_11_bn (BatchNormaliz (None, 8, 8, 512)         2048
_________________________________________________________________
conv_pw_11_relu (Activation) (None, 8, 8, 512)         0
_________________________________________________________________
conv_pad_12 (ZeroPadding2D)  (None, 10, 10, 512)       0
_________________________________________________________________
conv_dw_12 (DepthwiseConv2D) (None, 4, 4, 512)         4608
_________________________________________________________________
conv_dw_12_bn (BatchNormaliz (None, 4, 4, 512)         2048
_________________________________________________________________
conv_dw_12_relu (Activation) (None, 4, 4, 512)         0
_________________________________________________________________
conv_pw_12 (Conv2D)          (None, 4, 4, 1024)        524288
_________________________________________________________________
conv_pw_12_bn (BatchNormaliz (None, 4, 4, 1024)        4096
_________________________________________________________________
conv_pw_12_relu (Activation) (None, 4, 4, 1024)        0
_________________________________________________________________
conv_pad_13 (ZeroPadding2D)  (None, 6, 6, 1024)        0
_________________________________________________________________
conv_dw_13 (DepthwiseConv2D) (None, 4, 4, 1024)        9216
_________________________________________________________________
conv_dw_13_bn (BatchNormaliz (None, 4, 4, 1024)        4096
_________________________________________________________________
conv_dw_13_relu (Activation) (None, 4, 4, 1024)        0
_________________________________________________________________
conv_pw_13 (Conv2D)          (None, 4, 4, 1024)        1048576
_________________________________________________________________
conv_pw_13_bn (BatchNormaliz (None, 4, 4, 1024)        4096
_________________________________________________________________
conv_pw_13_relu (Activation) (None, 4, 4, 1024)        0
_________________________________________________________________
flatten_1 (Flatten)          (None, 16384)             0
_________________________________________________________________
dropout_1 (Dropout)          (None, 16384)             0
_________________________________________________________________
dense_1 (Dense)              (None, 156)               2556060
=================================================================
Total params: 5,784,924
Trainable params: 5,763,036
Non-trainable params: 21,888
_________________________________________________________________

As we can see, roughly half of the network's parameters sit in the final dense layer (2,556,060 of 5,784,924). So my question is: given a network I have already trained, how can I reduce the model size? I have already tried global average pooling instead of the dense layer, and it performed poorly on my regression task, so that is not an option. I am therefore looking for something along the lines of shrinking the dense layer or making it sparse.
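One way to see why a low-rank bottleneck helps here: replacing the single 16384 → 156 dense layer with two smaller dense layers (Dense(k, use_bias=False) followed by Dense(156)) cuts the parameter count roughly in proportion to k. A back-of-the-envelope sketch, where k is a hypothetical rank that would need tuning against validation error:

```python
# Parameter count of the original dense layer: weights plus biases.
in_dim, out_dim = 16384, 156
full_params = in_dim * out_dim + out_dim  # matches the 2,556,060 in the summary

# A rank-k bottleneck: Dense(k, use_bias=False) then Dense(out_dim).
def bottleneck_params(k):
    return in_dim * k + k * out_dim + out_dim

for k in (32, 64, 128):
    print(k, bottleneck_params(k), round(bottleneck_params(k) / full_params, 3))
```

For k = 64, for example, the head drops from ~2.56M to ~1.06M parameters; whether the regression accuracy survives the rank cut is an empirical question.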

Update

Example of the network with global average pooling:

# Imports needed by this snippet (config and mae_loss come from my own code):
from keras.applications.mobilenet import MobileNet
from keras.layers import Conv2D, GlobalAveragePooling2D
from keras.models import Model
from keras.optimizers import Adadelta

def MobileNet_v2():
    # MobileNet with a GAP layer on top

    # Keras 2.1.6
    mobilenet = MobileNet(input_shape=(config.IMAGE_H, config.IMAGE_W, config.N_CHANNELS),
                          alpha=1.0,
                          depth_multiplier=1,
                          include_top=False,
                          weights='imagenet'
                          )

    x = Conv2D(filters=config.N_LANDMARKS * 2, kernel_size=(1,1), activation='linear')(mobilenet.output)
    x = GlobalAveragePooling2D()(x)

    # -------------------------------------------------------

    model = Model(inputs=mobilenet.input, outputs=x)
    optimizer = Adadelta()
    model.compile(optimizer=optimizer, loss=mae_loss)

    model.summary()
    import sys
    sys.exit()  # debug only: stop after printing the architecture

    return model


_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, 128, 128, 3)       0
_________________________________________________________________
conv1_pad (ZeroPadding2D)    (None, 130, 130, 3)       0
_________________________________________________________________
conv1 (Conv2D)               (None, 64, 64, 32)        864
_________________________________________________________________
conv1_bn (BatchNormalization (None, 64, 64, 32)        128
_________________________________________________________________
conv1_relu (Activation)      (None, 64, 64, 32)        0
_________________________________________________________________
conv_pad_1 (ZeroPadding2D)   (None, 66, 66, 32)        0
_________________________________________________________________
conv_dw_1 (DepthwiseConv2D)  (None, 64, 64, 32)        288
_________________________________________________________________
conv_dw_1_bn (BatchNormaliza (None, 64, 64, 32)        128
_________________________________________________________________
conv_dw_1_relu (Activation)  (None, 64, 64, 32)        0
_________________________________________________________________
conv_pw_1 (Conv2D)           (None, 64, 64, 64)        2048
_________________________________________________________________
conv_pw_1_bn (BatchNormaliza (None, 64, 64, 64)        256
_________________________________________________________________
conv_pw_1_relu (Activation)  (None, 64, 64, 64)        0
_________________________________________________________________
conv_pad_2 (ZeroPadding2D)   (None, 66, 66, 64)        0
_________________________________________________________________
conv_dw_2 (DepthwiseConv2D)  (None, 32, 32, 64)        576
_________________________________________________________________
conv_dw_2_bn (BatchNormaliza (None, 32, 32, 64)        256
_________________________________________________________________
conv_dw_2_relu (Activation)  (None, 32, 32, 64)        0
_________________________________________________________________
conv_pw_2 (Conv2D)           (None, 32, 32, 128)       8192
_________________________________________________________________
conv_pw_2_bn (BatchNormaliza (None, 32, 32, 128)       512
_________________________________________________________________
conv_pw_2_relu (Activation)  (None, 32, 32, 128)       0
_________________________________________________________________
conv_pad_3 (ZeroPadding2D)   (None, 34, 34, 128)       0
_________________________________________________________________
conv_dw_3 (DepthwiseConv2D)  (None, 32, 32, 128)       1152
_________________________________________________________________
conv_dw_3_bn (BatchNormaliza (None, 32, 32, 128)       512
_________________________________________________________________
conv_dw_3_relu (Activation)  (None, 32, 32, 128)       0
_________________________________________________________________
conv_pw_3 (Conv2D)           (None, 32, 32, 128)       16384
_________________________________________________________________
conv_pw_3_bn (BatchNormaliza (None, 32, 32, 128)       512
_________________________________________________________________
conv_pw_3_relu (Activation)  (None, 32, 32, 128)       0
_________________________________________________________________
conv_pad_4 (ZeroPadding2D)   (None, 34, 34, 128)       0
_________________________________________________________________
conv_dw_4 (DepthwiseConv2D)  (None, 16, 16, 128)       1152
_________________________________________________________________
conv_dw_4_bn (BatchNormaliza (None, 16, 16, 128)       512
_________________________________________________________________
conv_dw_4_relu (Activation)  (None, 16, 16, 128)       0
_________________________________________________________________
conv_pw_4 (Conv2D)           (None, 16, 16, 256)       32768
_________________________________________________________________
conv_pw_4_bn (BatchNormaliza (None, 16, 16, 256)       1024
_________________________________________________________________
conv_pw_4_relu (Activation)  (None, 16, 16, 256)       0
_________________________________________________________________
conv_pad_5 (ZeroPadding2D)   (None, 18, 18, 256)       0
_________________________________________________________________
conv_dw_5 (DepthwiseConv2D)  (None, 16, 16, 256)       2304
_________________________________________________________________
conv_dw_5_bn (BatchNormaliza (None, 16, 16, 256)       1024
_________________________________________________________________
conv_dw_5_relu (Activation)  (None, 16, 16, 256)       0
_________________________________________________________________
conv_pw_5 (Conv2D)           (None, 16, 16, 256)       65536
_________________________________________________________________
conv_pw_5_bn (BatchNormaliza (None, 16, 16, 256)       1024
_________________________________________________________________
conv_pw_5_relu (Activation)  (None, 16, 16, 256)       0
_________________________________________________________________
conv_pad_6 (ZeroPadding2D)   (None, 18, 18, 256)       0
_________________________________________________________________
conv_dw_6 (DepthwiseConv2D)  (None, 8, 8, 256)         2304
_________________________________________________________________
conv_dw_6_bn (BatchNormaliza (None, 8, 8, 256)         1024
_________________________________________________________________
conv_dw_6_relu (Activation)  (None, 8, 8, 256)         0
_________________________________________________________________
conv_pw_6 (Conv2D)           (None, 8, 8, 512)         131072
_________________________________________________________________
conv_pw_6_bn (BatchNormaliza (None, 8, 8, 512)         2048
_________________________________________________________________
conv_pw_6_relu (Activation)  (None, 8, 8, 512)         0
_________________________________________________________________
conv_pad_7 (ZeroPadding2D)   (None, 10, 10, 512)       0
_________________________________________________________________
conv_dw_7 (DepthwiseConv2D)  (None, 8, 8, 512)         4608
_________________________________________________________________
conv_dw_7_bn (BatchNormaliza (None, 8, 8, 512)         2048
_________________________________________________________________
conv_dw_7_relu (Activation)  (None, 8, 8, 512)         0
_________________________________________________________________
conv_pw_7 (Conv2D)           (None, 8, 8, 512)         262144
_________________________________________________________________
conv_pw_7_bn (BatchNormaliza (None, 8, 8, 512)         2048
_________________________________________________________________
conv_pw_7_relu (Activation)  (None, 8, 8, 512)         0
_________________________________________________________________
conv_pad_8 (ZeroPadding2D)   (None, 10, 10, 512)       0
_________________________________________________________________
conv_dw_8 (DepthwiseConv2D)  (None, 8, 8, 512)         4608
_________________________________________________________________
conv_dw_8_bn (BatchNormaliza (None, 8, 8, 512)         2048
_________________________________________________________________
conv_dw_8_relu (Activation)  (None, 8, 8, 512)         0
_________________________________________________________________
conv_pw_8 (Conv2D)           (None, 8, 8, 512)         262144
_________________________________________________________________
conv_pw_8_bn (BatchNormaliza (None, 8, 8, 512)         2048
_________________________________________________________________
conv_pw_8_relu (Activation)  (None, 8, 8, 512)         0
_________________________________________________________________
conv_pad_9 (ZeroPadding2D)   (None, 10, 10, 512)       0
_________________________________________________________________
conv_dw_9 (DepthwiseConv2D)  (None, 8, 8, 512)         4608
_________________________________________________________________
conv_dw_9_bn (BatchNormaliza (None, 8, 8, 512)         2048
_________________________________________________________________
conv_dw_9_relu (Activation)  (None, 8, 8, 512)         0
_________________________________________________________________
conv_pw_9 (Conv2D)           (None, 8, 8, 512)         262144
_________________________________________________________________
conv_pw_9_bn (BatchNormaliza (None, 8, 8, 512)         2048
_________________________________________________________________
conv_pw_9_relu (Activation)  (None, 8, 8, 512)         0
_________________________________________________________________
conv_pad_10 (ZeroPadding2D)  (None, 10, 10, 512)       0
_________________________________________________________________
conv_dw_10 (DepthwiseConv2D) (None, 8, 8, 512)         4608
_________________________________________________________________
conv_dw_10_bn (BatchNormaliz (None, 8, 8, 512)         2048
_________________________________________________________________
conv_dw_10_relu (Activation) (None, 8, 8, 512)         0
_________________________________________________________________
conv_pw_10 (Conv2D)          (None, 8, 8, 512)         262144
_________________________________________________________________
conv_pw_10_bn (BatchNormaliz (None, 8, 8, 512)         2048
_________________________________________________________________
conv_pw_10_relu (Activation) (None, 8, 8, 512)         0
_________________________________________________________________
conv_pad_11 (ZeroPadding2D)  (None, 10, 10, 512)       0
_________________________________________________________________
conv_dw_11 (DepthwiseConv2D) (None, 8, 8, 512)         4608
_________________________________________________________________
conv_dw_11_bn (BatchNormaliz (None, 8, 8, 512)         2048
_________________________________________________________________
conv_dw_11_relu (Activation) (None, 8, 8, 512)         0
_________________________________________________________________
conv_pw_11 (Conv2D)          (None, 8, 8, 512)         262144
_________________________________________________________________
conv_pw_11_bn (BatchNormaliz (None, 8, 8, 512)         2048
_________________________________________________________________
conv_pw_11_relu (Activation) (None, 8, 8, 512)         0
_________________________________________________________________
conv_pad_12 (ZeroPadding2D)  (None, 10, 10, 512)       0
_________________________________________________________________
conv_dw_12 (DepthwiseConv2D) (None, 4, 4, 512)         4608
_________________________________________________________________
conv_dw_12_bn (BatchNormaliz (None, 4, 4, 512)         2048
_________________________________________________________________
conv_dw_12_relu (Activation) (None, 4, 4, 512)         0
_________________________________________________________________
conv_pw_12 (Conv2D)          (None, 4, 4, 1024)        524288
_________________________________________________________________
conv_pw_12_bn (BatchNormaliz (None, 4, 4, 1024)        4096
_________________________________________________________________
conv_pw_12_relu (Activation) (None, 4, 4, 1024)        0
_________________________________________________________________
conv_pad_13 (ZeroPadding2D)  (None, 6, 6, 1024)        0
_________________________________________________________________
conv_dw_13 (DepthwiseConv2D) (None, 4, 4, 1024)        9216
_________________________________________________________________
conv_dw_13_bn (BatchNormaliz (None, 4, 4, 1024)        4096
_________________________________________________________________
conv_dw_13_relu (Activation) (None, 4, 4, 1024)        0
_________________________________________________________________
conv_pw_13 (Conv2D)          (None, 4, 4, 1024)        1048576
_________________________________________________________________
conv_pw_13_bn (BatchNormaliz (None, 4, 4, 1024)        4096
_________________________________________________________________
conv_pw_13_relu (Activation) (None, 4, 4, 1024)        0
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 4, 4, 156)         159900
_________________________________________________________________
global_average_pooling2d_1 ( (None, 156)               0
=================================================================
Total params: 3,388,764
Trainable params: 3,366,876
Non-trainable params: 21,888

1 Answer:

Answer 0 (score: 0)

Regarding reducing the size of the dense layer:

> Compress the weight matrix W of an inner product (fully-connected) layer using truncated SVD.

https://github.com/rbgirshick/py-faster-rcnn/blob/master/tools/compress_net.py
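The linked compress_net.py applies this idea to Caffe models; the same truncated-SVD trick can be sketched in plain NumPy for a trained dense layer's weight matrix. The random W and the rank k below are placeholders: in practice W would be the trained 16384 × 156 weight matrix, and k would be tuned against validation error.

```python
import numpy as np

# Stand-in for the trained 16384 x 156 dense-layer weight matrix.
rng = np.random.default_rng(0)
W = rng.standard_normal((16384, 156)).astype(np.float32)

k = 64  # truncation rank (hypothetical; tune on validation data)

# Truncated SVD: W ~= (U * s) @ Vt, keeping only the k largest singular values.
U, s, Vt = np.linalg.svd(W, full_matrices=False)
A = U[:, :k] * s[:k]   # shape (16384, k): weights of the first small dense layer
B = Vt[:k, :]          # shape (k, 156):   weights of the second small dense layer

W_approx = A @ B       # low-rank approximation of the original layer
print(W.size, A.size + B.size)  # stored weights before vs after compression
```

The two factors become two stacked dense layers (the first without bias, the second carrying the original bias), so the compressed model can be loaded back into Keras without retraining, at the cost of some approximation error.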