When I call save_weights(), my RCNN model is too large: the saved file is close to 1 GB. I would like to reduce its size.
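I use a loop to imitate a simple RNN, but the input is different at each step. I also need the outputs of all steps stacked, so that I can compute the total loss over every step. I tried to rewrite it with the TimeDistributed layer, but without success. Do you have any suggestions?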
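The total loss described above can be sketched in plain Python: every unrolled step's prediction is scored against the same target and the per-step errors are summed. This is a minimal sketch that assumes a mean-squared-error loss, equal weight for every step, and one target shared by all steps; `step_mse` and `total_loss` are hypothetical names, not part of the original code.

```python
def step_mse(pred, target):
    """Mean squared error over all coordinates of a list of (x, y) points."""
    squared = [(px - tx) ** 2 + (py - ty) ** 2
               for (px, py), (tx, ty) in zip(pred, target)]
    return sum(squared) / (2 * len(squared))  # 2 coordinates per point


def total_loss(step_preds, target):
    """Sum of per-step losses (assumption: every step uses the same target)."""
    return sum(step_mse(pred, target) for pred in step_preds)


target = [(0.0, 0.0), (1.0, 1.0)]
step_preds = [
    [(1.0, 0.0), (1.0, 1.0)],  # step 0: first point is off by 1 in x
    [(0.0, 0.0), (1.0, 1.0)],  # step 1: exact
]
print(total_loss(step_preds, target))  # 0.25
```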
import tensorflow as tf

# `const` (configuration constants) and `shape` (input image shape) are defined
# elsewhere in the project and are not shown here.

def extract_patches(inputs):
    # inputs[0]: images (batch, H, W, C); inputs[1]: patch positions (batch, num_patches, 2)
    list_patches = []
    for j in range(const.num_patches):
        patch_one = tf.image.extract_glimpse(inputs[0], [const.size_patch[0], const.size_patch[1]], inputs[1][:, j, :], centered=False, normalized=False, noise='zero')
        list_patches.append(patch_one)
    patches = tf.keras.backend.stack(list_patches, 1)
    # Fold the patch axis into the batch axis: (batch * num_patches, h, w, C)
    return tf.keras.backend.reshape(patches, (-1, patches.shape[2], patches.shape[3], patches.shape[4]))

def reshape(inputs):
    # Unfold: (batch * num_patches, h, w, c) -> (batch, num_patches * h * w * c)
    return tf.keras.backend.reshape(inputs, (-1, const.num_patches * inputs.shape[1] * inputs.shape[2] * inputs.shape[3]))

def stack(inputs):
    # Stacks along axis 0, so the output shape is (num_iters, batch, num_patches, 2)
    return tf.keras.backend.stack(inputs)

# The helpers above must be defined before the Lambda layers below reference them.
x_input = tf.keras.layers.Input((shape[1], shape[2], const.num_channels), name='x_input')
y_init = tf.keras.layers.Input((const.num_patches, 2), name='y_init')
dxs = []
for i in range(const.num_iters_rnn):
    # Step 0 starts from y_init; every later step starts from the previous prediction
    prev = y_init if i == 0 else dxs[i - 1]
    # Every layer below is created anew on each iteration, so each step has its own weights
    patches = tf.keras.layers.Lambda(extract_patches)([x_input, prev])
    conv2d1 = tf.keras.layers.Conv2D(32, (3, 3), padding='same', activation='relu')(patches)
    maxpool1 = tf.keras.layers.MaxPooling2D()(conv2d1)
    conv2d2 = tf.keras.layers.Conv2D(32, (3, 3), padding='same', activation='relu')(maxpool1)
    maxpool2 = tf.keras.layers.MaxPooling2D()(conv2d2)
    crop = tf.keras.layers.Cropping2D(cropping=(const.crop_size, const.crop_size))(conv2d2)
    cnn = tf.keras.layers.concatenate([crop, maxpool2])
    cnn = tf.keras.layers.Lambda(reshape)(cnn)
    if i == 0:  # use ==, not `is`, to compare integers
        hidden_state = tf.keras.layers.Dense(const.numNeurons, activation='tanh')(cnn)
    else:
        concat = tf.keras.layers.concatenate([cnn, hidden_state], axis=1)
        hidden_state = tf.keras.layers.Dense(const.numNeurons, activation='tanh')(concat)
    hidden_state = tf.keras.layers.BatchNormalization()(hidden_state)
    prediction = tf.keras.layers.Dense(const.num_patches * 2, activation=None)(hidden_state)
    prediction = tf.keras.layers.Dropout(0.5)(prediction)
    prediction_reshape = tf.keras.layers.Reshape((const.num_patches, 2))(prediction)
    prediction = tf.keras.layers.Add()([prediction_reshape, prev])
    dxs.append(prediction)

output = tf.keras.layers.Lambda(stack)(dxs)
model = tf.keras.models.Model(inputs=[x_input, y_init], outputs=[output])
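The `extract_patches` and `reshape` helpers fold the patch axis into the batch axis, so that the same Conv2D stack runs on every patch, and then unfold the result into one feature vector per sample. A small NumPy sketch with toy sizes (standing in for 52 patches of 6x6x64 features) shows that this round trip keeps each sample's patches together. It assumes `np.reshape` uses the same row-major order as `tf.reshape`, and it skips the convolutions in between, which act on each batch row independently and so do not mix samples.

```python
import numpy as np

batch, num_patches, h, w, c = 2, 3, 4, 4, 5  # toy stand-ins for (batch, 52, 6, 6, 64)

# Per-sample patch stack, as produced by stacking glimpses along axis 1
patches = np.arange(batch * num_patches * h * w * c).reshape(batch, num_patches, h, w, c)

# extract_patches: fold the patch axis into the batch axis
folded = patches.reshape(-1, h, w, c)                   # (batch * num_patches, h, w, c)

# reshape(): unfold into one flat feature vector per sample
unfolded = folded.reshape(-1, num_patches * h * w * c)  # (batch, num_patches * h * w * c)

print(folded.shape, unfolded.shape)
# Row 1 of `unfolded` holds exactly the patches of sample 1, in order
print(np.array_equal(unfolded[1], patches[1].ravel()))
```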
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
x_input (InputLayer)            [(None, 255, 235, 1) 0
__________________________________________________________________________________________________
y_init (InputLayer)             [(None, 52, 2)]      0
__________________________________________________________________________________________________
lambda (Lambda)                 (None, 26, 26, 1)    0           x_input[0][0]
                                                                 y_init[0][0]
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 26, 26, 32)   320         lambda[0][0]
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 13, 13, 32)   0           conv2d[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 13, 13, 32)   9248        max_pooling2d[0][0]
__________________________________________________________________________________________________
cropping2d (Cropping2D)         (None, 6, 6, 32)     0           conv2d_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 6, 6, 32)     0           conv2d_1[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 6, 6, 64)     0           cropping2d[0][0]
                                                                 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 119808)       0           concatenate[0][0]
__________________________________________________________________________________________________
dense (Dense)                   (None, 512)          61342208    lambda_1[0][0]
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 512)          2048        dense[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 104)          53352       batch_normalization[0][0]
__________________________________________________________________________________________________
dropout (Dropout)               (None, 104)          0           dense_1[0][0]
__________________________________________________________________________________________________
reshape (Reshape)               (None, 52, 2)        0           dropout[0][0]
__________________________________________________________________________________________________
add (Add)                       (None, 52, 2)        0           reshape[0][0]
                                                                 y_init[0][0]
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, 26, 26, 1)    0           x_input[0][0]
                                                                 add[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 26, 26, 32)   320         lambda_2[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 13, 13, 32)   0           conv2d_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 13, 13, 32)   9248        max_pooling2d_2[0][0]
__________________________________________________________________________________________________
cropping2d_1 (Cropping2D)       (None, 6, 6, 32)     0           conv2d_3[0][0]
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 6, 6, 32)     0           conv2d_3[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 6, 6, 64)     0           cropping2d_1[0][0]
                                                                 max_pooling2d_3[0][0]
__________________________________________________________________________________________________
lambda_3 (Lambda)               (None, 119808)       0           concatenate_1[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 120320)       0           lambda_3[0][0]
                                                                 batch_normalization[0][0]
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 512)          61604352    concatenate_2[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 512)          2048        dense_2[0][0]
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 104)          53352       batch_normalization_1[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 104)          0           dense_3[0][0]
__________________________________________________________________________________________________
reshape_1 (Reshape)             (None, 52, 2)        0           dropout_1[0][0]
__________________________________________________________________________________________________
add_1 (Add)                     (None, 52, 2)        0           reshape_1[0][0]
                                                                 add[0][0]
__________________________________________________________________________________________________
lambda_4 (Lambda)               (None, 26, 26, 1)    0           x_input[0][0]
                                                                 add_1[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 26, 26, 32)   320         lambda_4[0][0]
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D)  (None, 13, 13, 32)   0           conv2d_4[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 13, 13, 32)   9248        max_pooling2d_4[0][0]
__________________________________________________________________________________________________
cropping2d_2 (Cropping2D)       (None, 6, 6, 32)     0           conv2d_5[0][0]
__________________________________________________________________________________________________
max_pooling2d_5 (MaxPooling2D)  (None, 6, 6, 32)     0           conv2d_5[0][0]
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 6, 6, 64)     0           cropping2d_2[0][0]
                                                                 max_pooling2d_5[0][0]
__________________________________________________________________________________________________
lambda_5 (Lambda)               (None, 119808)       0           concatenate_3[0][0]
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 120320)       0           lambda_5[0][0]
                                                                 batch_normalization_1[0][0]
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 512)          61604352    concatenate_4[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 512)          2048        dense_4[0][0]
__________________________________________________________________________________________________
dense_5 (Dense)                 (None, 104)          53352       batch_normalization_2[0][0]
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 104)          0           dense_5[0][0]
__________________________________________________________________________________________________
reshape_2 (Reshape)             (None, 52, 2)        0           dropout_2[0][0]
__________________________________________________________________________________________________
add_2 (Add)                     (None, 52, 2)        0           reshape_2[0][0]
                                                                 add_1[0][0]
__________________________________________________________________________________________________
lambda_6 (Lambda)               (None, 26, 26, 1)    0           x_input[0][0]
                                                                 add_2[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 26, 26, 32)   320         lambda_6[0][0]
__________________________________________________________________________________________________
max_pooling2d_6 (MaxPooling2D)  (None, 13, 13, 32)   0           conv2d_6[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 13, 13, 32)   9248        max_pooling2d_6[0][0]
__________________________________________________________________________________________________
cropping2d_3 (Cropping2D)       (None, 6, 6, 32)     0           conv2d_7[0][0]
__________________________________________________________________________________________________
max_pooling2d_7 (MaxPooling2D)  (None, 6, 6, 32)     0           conv2d_7[0][0]
__________________________________________________________________________________________________
concatenate_5 (Concatenate)     (None, 6, 6, 64)     0           cropping2d_3[0][0]
                                                                 max_pooling2d_7[0][0]
__________________________________________________________________________________________________
lambda_7 (Lambda)               (None, 119808)       0           concatenate_5[0][0]
__________________________________________________________________________________________________
concatenate_6 (Concatenate)     (None, 120320)       0           lambda_7[0][0]
                                                                 batch_normalization_2[0][0]
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, 512)          61604352    concatenate_6[0][0]
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 512)          2048        dense_6[0][0]
__________________________________________________________________________________________________
dense_7 (Dense)                 (None, 104)          53352       batch_normalization_3[0][0]
__________________________________________________________________________________________________
dropout_3 (Dropout)             (None, 104)          0           dense_7[0][0]
__________________________________________________________________________________________________
reshape_3 (Reshape)             (None, 52, 2)        0           dropout_3[0][0]
__________________________________________________________________________________________________
add_3 (Add)                     (None, 52, 2)        0           reshape_3[0][0]
                                                                 add_2[0][0]
__________________________________________________________________________________________________
lambda_8 (Lambda)               (4, None, 52, 2)     0           add[0][0]
                                                                 add_1[0][0]
                                                                 add_2[0][0]
                                                                 add_3[0][0]
==================================================================================================
Total params: 246,415,136
Trainable params: 246,411,040
Non-trainable params: 4,096
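These totals account for the file size. A quick check in plain Python, assuming float32 weights (4 bytes per parameter), reproduces the parameter counts in the summary and compares them with a hypothetical variant in which one set of layers is reused at every step, as a true RNN cell would be. That variant is an assumption, not the original code: it would need a zero initial hidden state so that every step's Dense layer sees the same 120320-wide input, and sharing weights across steps may change accuracy. The helper names are illustrative only.

```python
def conv2d_params(k, c_in, c_out):
    return k * k * c_in * c_out + c_out   # kernel weights + bias


def dense_params(n_in, n_out):
    return n_in * n_out + n_out           # weight matrix + bias


def batchnorm_params(n):
    return 4 * n  # gamma, beta (trainable) + moving mean, moving variance


NUM_ITERS, NUM_PATCHES, NUM_NEURONS = 4, 52, 512
flat = NUM_PATCHES * 6 * 6 * 64  # width of the reshape Lambda output


def step_params(dense_in):
    """Parameters created by one iteration of the loop."""
    return (conv2d_params(3, 1, 32) + conv2d_params(3, 32, 32)
            + dense_params(dense_in, NUM_NEURONS)
            + batchnorm_params(NUM_NEURONS)
            + dense_params(NUM_NEURONS, NUM_PATCHES * 2))


# As built: step 0 sees only CNN features, later steps also see the hidden state
unshared = step_params(flat) + (NUM_ITERS - 1) * step_params(flat + NUM_NEURONS)
# Hypothetical shared variant: one set of layers reused by every step
shared = step_params(flat + NUM_NEURONS)

print(flat, unshared, shared)
print("unshared MB:", unshared * 4 / 1e6)  # 4 bytes per float32 parameter
print("shared MB:", shared * 4 / 1e6)
```

At roughly 986 MB of raw float32 weights, the file size matches what save_weights() produces. In Keras, sharing would amount to creating each layer once before the loop and calling the same layer instance inside it.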
Answer 0 (score: 1)
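For a model of this size, 1 GB is a reasonable file size, so the only way to get a smaller file is to make the model itself smaller. Some techniques do this without lowering the final accuracy, and in some cases they even improve it. Search for pruning neural networks.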
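A minimal sketch of one common form of pruning, unstructured magnitude pruning, in plain Python: the given fraction of weights with the smallest absolute value is set to zero. The function name `magnitude_prune` is hypothetical; for Keras models, the TensorFlow Model Optimization Toolkit provides a `prune_low_magnitude` wrapper that applies this idea during training.

```python
def magnitude_prune(weights, sparsity):
    """Zero out the `sparsity` fraction of weights with the smallest |w|."""
    n_prune = int(len(weights) * sparsity)
    # Indices sorted from smallest to largest magnitude
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    to_zero = set(order[:n_prune])
    return [0.0 if i in to_zero else w for i, w in enumerate(weights)]


w = [0.5, -0.01, 0.2, -0.9, 0.03, 0.0, 0.7, -0.05]
pruned = magnitude_prune(w, 0.5)
print(pruned)  # [0.5, 0.0, 0.2, -0.9, 0.0, 0.0, 0.7, 0.0]
```

Zeroed weights still take 4 bytes each in a dense float32 file, so the saved weights only get smaller once they are compressed (for example with gzip) or stored in a sparse format.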