My RCNN model is too big when saving weights. How can I reduce its size?

Date: 2019-06-12 16:37:23

Tags: tensorflow lambda keras

When I call save_weights(), the file for my RCNN model is too big: close to 1 GB. I want to reduce its size.

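I use a loop to imitate a simple RNN, but the input is different at each step. I also need the outputs of all steps stacked together so that I can compute the total loss over every step. I tried to rewrite it with a TimeDistributed layer, but without success. Do you have any suggestions?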

    import tensorflow as tf

    x_input = tf.keras.layers.Input((shape[1], shape[2], const.num_channels), name='x_input')
    y_init = tf.keras.layers.Input((const.num_patches, 2), name='y_init')
    dxs = []

    for i in range(const.num_iters_rnn):

        if i == 0:
            patches = tf.keras.layers.Lambda(extract_patches)([x_input,y_init])
        else:
            patches = tf.keras.layers.Lambda(extract_patches)([x_input,dxs[i-1]])

        # Note: new layer objects are created on every iteration, so each step gets its own weights.
        conv2d1 = tf.keras.layers.Conv2D(32, (3, 3), padding='same', activation='relu')(patches)
        maxpool1 = tf.keras.layers.MaxPooling2D()(conv2d1)
        conv2d2 = tf.keras.layers.Conv2D(32, (3, 3), padding='same', activation='relu')(maxpool1)
        maxpool2 = tf.keras.layers.MaxPooling2D()(conv2d2)
        crop = tf.keras.layers.Cropping2D(cropping=(const.crop_size, const.crop_size))(conv2d2)
        cnn = tf.keras.layers.concatenate([crop, maxpool2])

        cnn = tf.keras.layers.Lambda(reshape)(cnn)


        if i == 0:
            hidden_state = tf.keras.layers.Dense(const.numNeurons,activation='tanh')(cnn)
        else:
            concat = tf.keras.layers.concatenate([cnn,hidden_state],axis=1)
            hidden_state = tf.keras.layers.Dense(const.numNeurons,activation='tanh')(concat)

        hidden_state = tf.keras.layers.BatchNormalization()(hidden_state)

        prediction = tf.keras.layers.Dense(const.num_patches*2,activation=None)(hidden_state)

        prediction = tf.keras.layers.Dropout(0.5)(prediction)

        prediction_reshape = tf.keras.layers.Reshape((const.num_patches, 2))(prediction)

        if i == 0:
            prediction = tf.keras.layers.Add()([prediction_reshape, y_init])
            dxs.append(prediction)
        else:
            prediction = tf.keras.layers.Add()([prediction_reshape,dxs[i-1]])
            dxs.append(prediction)

    output = tf.keras.layers.Lambda(stack)(dxs)

    model = tf.keras.models.Model(inputs=[x_input, y_init], outputs=[output])

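    def extract_patches(inputs):
        # inputs[0]: images (batch, H, W, C); inputs[1]: patch centres (batch, num_patches, 2)
        # as (y, x) pixel coordinates. Cuts one glimpse of size const.size_patch around each centre.
        list_patches = []
        for j in range(const.num_patches):
            patch_one = tf.image.extract_glimpse(inputs[0], [const.size_patch[0], const.size_patch[1]], inputs[1][:, j, :],
                                                 centered=False, normalized=False, noise='zero')
            list_patches.append(patch_one)

        # (batch, num_patches, h, w, C) -> (batch * num_patches, h, w, C): the CNN treats every patch as a sample
        patches = tf.keras.backend.stack(list_patches, 1)
        return tf.keras.backend.reshape(patches, (-1, patches.shape[2], patches.shape[3], patches.shape[4]))

    def reshape(inputs):
        # Regroups the per-patch feature maps by image and flattens them:
        # (batch * num_patches, 6, 6, 64) -> (batch, num_patches * 6 * 6 * 64) = (batch, 119808)
        return tf.keras.backend.reshape(inputs, (-1, const.num_patches * inputs.shape[1] * inputs.shape[2] * inputs.shape[3]))

    def stack(inputs):
        # Stacks the per-step predictions along axis 0, so the step axis comes before the batch
        # axis: (num_iters_rnn, None, num_patches, 2), cf. lambda_8 in the model summary.
        return tf.keras.backend.stack(inputs)
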
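In the loop above, every iteration constructs fresh `Conv2D`, `Dense` and `BatchNormalization` objects, so each step has its own copy of the weights; that is why the model summary lists four separate 512-unit Dense layers. A simple RNN instead reuses one set of weights at every step. The following is a minimal sketch of layer reuse in the Keras functional API, not the asker's model: the sizes (`NUM_ITERS`, `FEATURES`, `NUM_NEURONS`, `NUM_OUT`) are hypothetical toy values, the patch extraction, CNN, BatchNormalization and Dropout are left out, and the zero initial hidden state is assumed to be passed in by the caller so every step sees the same input width.

```python
import tensorflow as tf

# Hypothetical toy sizes for illustration only; the question reads these from `const`.
NUM_ITERS = 4      # plays the role of const.num_iters_rnn
FEATURES = 128     # width of the flattened CNN features of one step
NUM_NEURONS = 32   # plays the role of const.numNeurons
NUM_OUT = 104      # const.num_patches * 2

# One feature input per step (in the question these come from patches re-extracted at the
# previous prediction), plus an initial hidden state that the caller fills with zeros.
step_features = [tf.keras.layers.Input((FEATURES,), name=f'features_{t}') for t in range(NUM_ITERS)]
h0 = tf.keras.layers.Input((NUM_NEURONS,), name='h0')

# Each layer is created ONCE, outside the loop ...
step_dense = tf.keras.layers.Dense(NUM_NEURONS, activation='tanh')
out_dense = tf.keras.layers.Dense(NUM_OUT)

hidden = h0
per_step = []
for t in range(NUM_ITERS):
    # ... and called in every iteration, so all steps share the same kernel and bias.
    hidden = step_dense(tf.keras.layers.concatenate([step_features[t], hidden]))
    pred = out_dense(hidden)
    per_step.append(tf.keras.layers.Reshape((1, NUM_OUT))(pred))

# Batch axis first, then the step axis: (None, NUM_ITERS, NUM_OUT), ready for a per-step loss.
stacked = tf.keras.layers.Concatenate(axis=1)(per_step)

model = tf.keras.Model(inputs=step_features + [h0], outputs=stacked)
print(tuple(stacked.shape), model.count_params())
```

The parameter count no longer grows with `NUM_ITERS`: adding steps adds only connections in the graph, not weights. Sharing a BatchNormalization layer across steps is a separate design decision, since the activation statistics can differ from step to step.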
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
x_input (InputLayer)            [(None, 255, 235, 1) 0
__________________________________________________________________________________________________
y_init (InputLayer)             [(None, 52, 2)]      0
__________________________________________________________________________________________________
lambda (Lambda)                 (None, 26, 26, 1)    0           x_input[0][0]
                                                                 y_init[0][0]
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 26, 26, 32)   320         lambda[0][0]
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 13, 13, 32)   0           conv2d[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 13, 13, 32)   9248        max_pooling2d[0][0]
__________________________________________________________________________________________________
cropping2d (Cropping2D)         (None, 6, 6, 32)     0           conv2d_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 6, 6, 32)     0           conv2d_1[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 6, 6, 64)     0           cropping2d[0][0]
                                                                 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 119808)       0           concatenate[0][0]
__________________________________________________________________________________________________
dense (Dense)                   (None, 512)          61342208    lambda_1[0][0]
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 512)          2048        dense[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 104)          53352       batch_normalization[0][0]
__________________________________________________________________________________________________
dropout (Dropout)               (None, 104)          0           dense_1[0][0]
__________________________________________________________________________________________________
reshape (Reshape)               (None, 52, 2)        0           dropout[0][0]
__________________________________________________________________________________________________
add (Add)                       (None, 52, 2)        0           reshape[0][0]
                                                                 y_init[0][0]
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, 26, 26, 1)    0           x_input[0][0]
                                                                 add[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 26, 26, 32)   320         lambda_2[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 13, 13, 32)   0           conv2d_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 13, 13, 32)   9248        max_pooling2d_2[0][0]
__________________________________________________________________________________________________
cropping2d_1 (Cropping2D)       (None, 6, 6, 32)     0           conv2d_3[0][0]
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 6, 6, 32)     0           conv2d_3[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 6, 6, 64)     0           cropping2d_1[0][0]
                                                                 max_pooling2d_3[0][0]
__________________________________________________________________________________________________
lambda_3 (Lambda)               (None, 119808)       0           concatenate_1[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 120320)       0           lambda_3[0][0]
                                                                 batch_normalization[0][0]
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 512)          61604352    concatenate_2[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 512)          2048        dense_2[0][0]
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 104)          53352       batch_normalization_1[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 104)          0           dense_3[0][0]
__________________________________________________________________________________________________
reshape_1 (Reshape)             (None, 52, 2)        0           dropout_1[0][0]
__________________________________________________________________________________________________
add_1 (Add)                     (None, 52, 2)        0           reshape_1[0][0]
                                                                 add[0][0]
__________________________________________________________________________________________________
lambda_4 (Lambda)               (None, 26, 26, 1)    0           x_input[0][0]
                                                                 add_1[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 26, 26, 32)   320         lambda_4[0][0]
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D)  (None, 13, 13, 32)   0           conv2d_4[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 13, 13, 32)   9248        max_pooling2d_4[0][0]
__________________________________________________________________________________________________
cropping2d_2 (Cropping2D)       (None, 6, 6, 32)     0           conv2d_5[0][0]
__________________________________________________________________________________________________
max_pooling2d_5 (MaxPooling2D)  (None, 6, 6, 32)     0           conv2d_5[0][0]
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 6, 6, 64)     0           cropping2d_2[0][0]
                                                                 max_pooling2d_5[0][0]
__________________________________________________________________________________________________
lambda_5 (Lambda)               (None, 119808)       0           concatenate_3[0][0]
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 120320)       0           lambda_5[0][0]
                                                                 batch_normalization_1[0][0]
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 512)          61604352    concatenate_4[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 512)          2048        dense_4[0][0]
__________________________________________________________________________________________________
dense_5 (Dense)                 (None, 104)          53352       batch_normalization_2[0][0]
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 104)          0           dense_5[0][0]
__________________________________________________________________________________________________
reshape_2 (Reshape)             (None, 52, 2)        0           dropout_2[0][0]
__________________________________________________________________________________________________
add_2 (Add)                     (None, 52, 2)        0           reshape_2[0][0]
                                                                 add_1[0][0]
__________________________________________________________________________________________________
lambda_6 (Lambda)               (None, 26, 26, 1)    0           x_input[0][0]
                                                                 add_2[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 26, 26, 32)   320         lambda_6[0][0]
__________________________________________________________________________________________________
max_pooling2d_6 (MaxPooling2D)  (None, 13, 13, 32)   0           conv2d_6[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 13, 13, 32)   9248        max_pooling2d_6[0][0]
__________________________________________________________________________________________________
cropping2d_3 (Cropping2D)       (None, 6, 6, 32)     0           conv2d_7[0][0]
__________________________________________________________________________________________________
max_pooling2d_7 (MaxPooling2D)  (None, 6, 6, 32)     0           conv2d_7[0][0]
__________________________________________________________________________________________________
concatenate_5 (Concatenate)     (None, 6, 6, 64)     0           cropping2d_3[0][0]
                                                                 max_pooling2d_7[0][0]
__________________________________________________________________________________________________
lambda_7 (Lambda)               (None, 119808)       0           concatenate_5[0][0]
__________________________________________________________________________________________________
concatenate_6 (Concatenate)     (None, 120320)       0           lambda_7[0][0]
                                                                 batch_normalization_2[0][0]
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, 512)          61604352    concatenate_6[0][0]
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 512)          2048        dense_6[0][0]
__________________________________________________________________________________________________
dense_7 (Dense)                 (None, 104)          53352       batch_normalization_3[0][0]      
__________________________________________________________________________________________________
dropout_3 (Dropout)             (None, 104)          0           dense_7[0][0]
__________________________________________________________________________________________________
reshape_3 (Reshape)             (None, 52, 2)        0           dropout_3[0][0]
__________________________________________________________________________________________________
add_3 (Add)                     (None, 52, 2)        0           reshape_3[0][0]
                                                                 add_2[0][0]
__________________________________________________________________________________________________
lambda_8 (Lambda)               (4, None, 52, 2)     0           add[0][0]
                                                                 add_1[0][0]
                                                                 add_2[0][0]
                                                                 add_3[0][0]
==================================================================================================
Total params: 246,415,136
Trainable params: 246,411,040
Non-trainable params: 4,096
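The ~1 GB file size is what this parameter count implies. Here is a worked check that uses only numbers from the summary above, assuming the default float32 weights (4 bytes each). The shared-weight line is a hypothetical variant, not the asker's model: it assumes one copy of every layer and a zero initial hidden state, so every step uses the wider Dense input.

```python
# Figures read off the model summary above (4 loop iterations, float32 weights).
num_iters = 4
num_neurons = 512
flat = 52 * 6 * 6 * 64            # lambda_1 output width: 52 patches x 6 x 6 x 64 channels = 119808

conv = (3 * 3 * 1 * 32 + 32) + (3 * 3 * 32 * 32 + 32)   # conv2d + conv2d_1 (per step)
bn = 4 * num_neurons                                       # gamma, beta, moving mean, moving variance
head = num_neurons * 104 + 104                             # dense_1 (per step)
first_dense = flat * num_neurons + num_neurons                   # dense: no hidden state yet
later_dense = (flat + num_neurons) * num_neurons + num_neurons   # dense_2/4/6: features + hidden state

total = first_dense + (num_iters - 1) * later_dense + num_iters * (conv + bn + head)
size_bytes = total * 4            # 4 bytes per float32 weight; total == 246,415,136, as in "Total params"

# Hypothetical shared-weight variant: one copy of every layer reused at all 4 steps.
shared_total = later_dense + conv + bn + head

big_dense = first_dense + (num_iters - 1) * later_dense
print(f"{total:,} params ~ {size_bytes / 2**20:.0f} MiB")
print(f"512-unit Dense layers: {big_dense / total:.2%} of all params")
print(f"shared-weight variant: {shared_total:,} params ~ {shared_total * 4 / 2**20:.0f} MiB")
```

Almost all of the size sits in the 512-unit Dense layers that read the 119,808-wide flattened patch features, one per step. Reusing one set of layers would cut the total to roughly a quarter; beyond that, the size is governed by the width of the flattened features and of the Dense layer.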

1 Answer:

Answer 0 (score: 1)

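You have to make the model itself smaller: for this architecture, 1 GB is a reasonable size. There are methods that do this without lowering the final accuracy, and in some cases they even improve it. Search for neural network pruning.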
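Magnitude pruning is one basic form of the pruning mentioned above: set the weights with the smallest absolute values to zero. A zeroed float32 array still occupies the same number of bytes in a weights file, so the saving only appears once the weights are compressed or stored sparsely. In practice the network is also fine-tuned after pruning to recover accuracy; for Keras models, the TensorFlow Model Optimization Toolkit's `prune_low_magnitude` does this gradually during training. The minimal NumPy sketch below works on a random stand-in matrix. The helper `magnitude_prune`, the matrix shape and the 80% sparsity are illustrative assumptions, not values from the question.

```python
import gzip

import numpy as np


def magnitude_prune(weights, sparsity):
    """Return a copy of `weights` with the `sparsity` fraction of smallest-|w| entries set to zero."""
    k = int(sparsity * weights.size)            # number of weights to drop
    if k == 0:
        return weights.copy()
    # The k-th smallest absolute value is the cut-off.
    threshold = np.partition(np.abs(weights).ravel(), k - 1)[k - 1]
    pruned = weights.copy()
    pruned[np.abs(pruned) <= threshold] = 0.0
    return pruned


rng = np.random.default_rng(0)
# Hypothetical stand-in for one Dense kernel (far smaller than the real 119808 x 512).
kernel = rng.standard_normal((1024, 512)).astype(np.float32)
pruned = magnitude_prune(kernel, sparsity=0.8)

# Zeros still take 4 bytes each, so the raw size is unchanged ...
raw_bytes = pruned.nbytes
# ... but they compress well, unlike the random-looking dense weights.
dense_gz = len(gzip.compress(kernel.tobytes()))
pruned_gz = len(gzip.compress(pruned.tobytes()))

print(f"sparsity: {np.mean(pruned == 0):.2f}")
print(f"raw: {raw_bytes / 2**20:.1f} MiB, gzip dense: {dense_gz / 2**20:.2f} MiB, "
      f"gzip pruned: {pruned_gz / 2**20:.2f} MiB")
```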