I am building an image semantic segmentation model with TensorFlow and Keras, and I am trying to use a PSPNet architecture for it.
I based my architecture mainly on these sources: https://divamgupta.com/image-segmentation/2019/06/06/deep-learning-semantic-segmentation-keras.html and https://developers.arcgis.com/python/guide/how-pspnet-works/
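A minimal NumPy sketch of the pyramid pooling module as I understand it from the sources above: pool the feature map into a few grid sizes, project each pooled map with a 1x1 convolution, upsample every branch back to the feature-map size, and concatenate the branches with the input. The bin sizes (1, 2, 3, 6) and the reduction to C/4 channels per branch are assumptions taken from the usual PSPNet description; the random weights stand in for learned 1x1 convolutions, and nearest-neighbour upsampling via `np.repeat` is used for brevity where PSPNet uses bilinear interpolation.

```python
import numpy as np

def avg_pool_to_bins(feat, bins):
    """Average-pool an (H, W, C) map into a (bins, bins, C) grid.

    Assumes H and W are divisible by `bins` (true for 36x36 and bins 1, 2, 3, 6).
    """
    h, w, c = feat.shape
    return feat.reshape(bins, h // bins, bins, w // bins, c).mean(axis=(1, 3))

def pyramid_pooling(feat, bin_sizes=(1, 2, 3, 6), proj_channels=None, seed=0):
    """Toy PSPNet-style pyramid pooling: pool -> 1x1 projection -> upsample -> concat."""
    h, w, c = feat.shape
    proj_channels = proj_channels or c // len(bin_sizes)
    rng = np.random.default_rng(seed)
    branches = [feat]  # the original feature map is kept as the first branch
    for b in bin_sizes:
        pooled = avg_pool_to_bins(feat, b)                  # (b, b, C)
        weights = rng.standard_normal((c, proj_channels))   # stand-in for a learned 1x1 conv
        projected = np.maximum(pooled @ weights, 0.0)       # 1x1 conv + ReLU -> (b, b, C')
        # Nearest-neighbour upsampling back to (H, W); PSPNet itself uses bilinear.
        up = np.repeat(np.repeat(projected, h // b, axis=0), w // b, axis=1)
        branches.append(up)
    return np.concatenate(branches, axis=-1)

features = np.random.default_rng(1).standard_normal((36, 36, 256))
out = pyramid_pooling(features)
print(out.shape)  # (36, 36, 512)
```

With a 36x36x256 input this yields the same 512 channels as the concatenate layer in my model. Note that my pool sizes of 36 (global), 18, 9 and 6 on a 36x36 map correspond to 1, 2, 4 and 6 bins rather than 1, 2, 3 and 6.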
The dataset I am using is the Oxford-IIIT Pet dataset (images of cats and dogs). For each image I have a corresponding segmentation mask.
All images have been resized to 288x288x1 (one channel), and I am using the Keras function to_categorical to one-hot encode all mask values.
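A minimal sketch of the mask encoding step, assuming the masks are the dataset's trimaps with labels 1 (pet), 2 (background) and 3 (border). The labels are shifted to 0..2 so they line up with the 3 softmax channels in the summary below, and the check rejects fractional values that appear if masks are resized with bilinear instead of nearest-neighbour interpolation. `encode_trimap` is a hypothetical helper name; the final line gives the same values as `tf.keras.utils.to_categorical(labels, num_classes)`, written in plain NumPy here.

```python
import numpy as np

NUM_CLASSES = 3  # the model's softmax output has 3 channels

def encode_trimap(mask, num_classes=NUM_CLASSES):
    """One-hot encode a trimap whose labels are 1..num_classes (hypothetical helper).

    Assumes Oxford-IIIT Pet trimap labels: 1 = pet, 2 = background, 3 = border.
    Labels are shifted to 0..num_classes-1 to match the output channels.
    """
    mask = np.asarray(mask)
    if mask.ndim == 3 and mask.shape[-1] == 1:  # (H, W, 1) -> (H, W)
        mask = mask[..., 0]
    valid = np.arange(1, num_classes + 1)
    if not np.isin(mask, valid).all():
        # Fractional or out-of-range values, e.g. from bilinear resizing of masks.
        raise ValueError(f"unexpected mask values: {np.unique(mask)}")
    labels = mask.astype(np.int64) - 1
    # Same values as tf.keras.utils.to_categorical(labels, num_classes).
    return np.eye(num_classes, dtype=np.float32)[labels]


example = np.array([[1, 2], [3, 2]])
encoded = encode_trimap(example)
print(encoded.shape)  # (2, 2, 3)
print(encoded[1, 0])  # [0. 0. 1.]
```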
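My problem is this: when I use the trained model to make segmentation predictions, the predictions are terrible. I believe the problem lies in the model architecture posted below. I have tried many different changes to the architecture, but none of them gave better results.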
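Because `metrics=['accuracy']` only reports pixel accuracy, which can look acceptable even when one class dominates the predictions, a per-class intersection-over-union gives a clearer measure of how bad the predictions are. A minimal NumPy sketch, assuming `y_true` holds the one-hot masks and `y_pred` the softmax output of `model.predict`, both shaped (N, H, W, num_classes); Keras also provides `tf.keras.metrics.MeanIoU`, which works on integer class labels.

```python
import numpy as np

def per_class_iou(y_true_onehot, y_pred_probs):
    """Per-class intersection-over-union for arrays shaped (N, H, W, num_classes)."""
    num_classes = y_true_onehot.shape[-1]
    true_labels = np.argmax(y_true_onehot, axis=-1)
    pred_labels = np.argmax(y_pred_probs, axis=-1)
    ious = []
    for c in range(num_classes):
        t = true_labels == c
        p = pred_labels == c
        union = np.logical_or(t, p).sum()
        inter = np.logical_and(t, p).sum()
        ious.append(inter / union if union > 0 else np.nan)  # NaN: class absent in both
    return np.array(ious)


# Toy check: a "model" that predicts class 2 everywhere.
y_true = np.eye(3)[np.array([[[0, 1], [2, 2]]])]  # shape (1, 2, 2, 3)
y_pred = np.zeros_like(y_true)
y_pred[..., 2] = 1.0
pixel_acc = np.mean(np.argmax(y_true, -1) == np.argmax(y_pred, -1))
print(pixel_acc)                      # 0.5
print(per_class_iou(y_true, y_pred))  # IoU of 0, 0 and 0.5 for classes 0, 1, 2
```

In this toy case pixel accuracy is 50% while two of the three classes have an IoU of zero.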
My model architecture:
```python
import tensorflow as tf

# Sizes used in this model (see the summary below): 288x288 grayscale input, 3 mask classes
sourceHeight, sourceWidth, sourceChannels = 288, 288, 1
num_classes = 3

# INPUT LAYER #
# Keras expects (height, width, channels); height and width are both 288 here
inputL = tf.keras.layers.Input((sourceHeight, sourceWidth, sourceChannels))
#inputL = tf.keras.applications.ResNet50()  # commented out, not used

# ENCODER #
# Block 1: 288x288 -> 144x144
conv1 = tf.keras.layers.Conv2D(32, (3,3), activation='relu', padding='same')(inputL)
conv1 = tf.keras.layers.Dropout(0.1)(conv1)
conv1 = tf.keras.layers.Conv2D(32, (3,3), activation='relu', padding='same')(conv1)
conv1 = tf.keras.layers.MaxPooling2D((2,2))(conv1)
# Block 2: 144x144 -> 72x72
conv2 = tf.keras.layers.Conv2D(64, (3,3), activation='relu', padding='same')(conv1)
conv2 = tf.keras.layers.Dropout(0.1)(conv2)
conv2 = tf.keras.layers.Conv2D(64, (3,3), activation='relu', padding='same')(conv2)
conv2 = tf.keras.layers.MaxPooling2D((2,2))(conv2)
# Block 3: dilation rate 2, 72x72 -> 36x36
conv3 = tf.keras.layers.Conv2D(128, (3,3), dilation_rate=(2,2), activation='relu', padding='same')(conv2)
conv3 = tf.keras.layers.Dropout(0.1)(conv3)
conv3 = tf.keras.layers.Conv2D(128, (3,3), dilation_rate=(2,2), activation='relu', padding='same')(conv3)
conv3 = tf.keras.layers.MaxPooling2D((2,2))(conv3)
# Block 4: dilation rate 4, stays at 36x36 with 256 channels
conv4 = tf.keras.layers.Conv2D(256, (3,3), dilation_rate=(4,4), activation='relu', padding='same')(conv3)
conv4 = tf.keras.layers.Dropout(0.1)(conv4)
conv4 = tf.keras.layers.Conv2D(256, (3,3), dilation_rate=(4,4), activation='relu', padding='same')(conv4)

# PYRAMID POOLING MODULE #
# Each branch pools conv4, reduces it to 64 channels with a 1x1 convolution,
# and upsamples it back to 36x36 with a strided Conv2DTranspose.
# Pyramid block 1: global average pooling -> 1x1 -> 36x36 #
pyramid_block1 = tf.keras.layers.GlobalAveragePooling2D()(conv4)
pyramid_block1 = tf.keras.layers.Reshape((1,1,256))(pyramid_block1)
pyramid_block1 = tf.keras.layers.Conv2D(64, (1,1), strides=(1,1), activation='relu')(pyramid_block1)
pyramid_block1 = tf.keras.layers.Conv2DTranspose(64, (1,1), strides=(36,36), padding='same')(pyramid_block1)
# Pyramid block 2: 18x18 pooling -> 2x2 -> 36x36 #
pyramid_block2 = tf.keras.layers.AveragePooling2D(pool_size=(18,18), strides=(18,18))(conv4)
pyramid_block2 = tf.keras.layers.Conv2D(64, (1,1), strides=(1,1), activation='relu')(pyramid_block2)
pyramid_block2 = tf.keras.layers.Conv2DTranspose(64, (2,2), strides=(18,18), padding='same')(pyramid_block2)
# Pyramid block 3: 9x9 pooling -> 4x4 -> 36x36 #
pyramid_block3 = tf.keras.layers.AveragePooling2D(pool_size=(9,9), strides=(9,9))(conv4)
pyramid_block3 = tf.keras.layers.Conv2D(64, (1,1), strides=(1,1), activation='relu')(pyramid_block3)
pyramid_block3 = tf.keras.layers.Conv2DTranspose(64, (3,3), strides=(9,9), padding='same')(pyramid_block3)
# Pyramid block 4: 6x6 pooling -> 6x6 -> 36x36 #
pyramid_block4 = tf.keras.layers.AveragePooling2D(pool_size=(6,6), strides=(6,6))(conv4)
pyramid_block4 = tf.keras.layers.Conv2D(64, (1,1), strides=(1,1), activation='relu')(pyramid_block4)
pyramid_block4 = tf.keras.layers.Conv2DTranspose(64, (2,2), strides=(6,6), padding='same')(pyramid_block4)

# DECODER #
# Concatenate conv4 with the four branches: 256 + 4 * 64 = 512 channels
pyramid = tf.keras.layers.concatenate([conv4, pyramid_block4, pyramid_block3, pyramid_block2, pyramid_block1])
conv7 = tf.keras.layers.Conv2D(64, (3,3), strides=(1,1), activation='relu', padding='same')(pyramid)
conv7 = tf.keras.layers.Dropout(0.2)(conv7)
conv7 = tf.keras.layers.Conv2D(64, (3,3), strides=(1,1), activation='relu', padding='same')(conv7)
# Upsample 36x36 -> 288x288 (stride 8)
conv8 = tf.keras.layers.Conv2DTranspose(64, (2,2), strides=(8,8), padding='same')(conv7)
conv8 = tf.keras.layers.Conv2D(32, (3,3), strides=(1,1), activation='relu', padding='same')(conv8)
conv8 = tf.keras.layers.Dropout(0.2)(conv8)
conv8 = tf.keras.layers.Conv2D(32, (3,3), strides=(1,1), activation='relu', padding='same')(conv8)
# Per-pixel softmax over the mask classes
output = tf.keras.layers.Conv2D(num_classes, (1,1), activation='softmax')(conv8)

model = tf.keras.Model(inputs=[inputL], outputs=[output])
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
```
Output of model.summary():

```text
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            (None, 288, 288, 1)  0
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 288, 288, 32) 320         input_1[0][0]
__________________________________________________________________________________________________
dropout (Dropout)               (None, 288, 288, 32) 0           conv2d[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 288, 288, 32) 9248        dropout[0][0]
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 144, 144, 32) 0           conv2d_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 144, 144, 64) 18496       max_pooling2d[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 144, 144, 64) 0           conv2d_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 144, 144, 64) 36928       dropout_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 72, 72, 64)   0           conv2d_3[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 72, 72, 128)  73856       max_pooling2d_1[0][0]
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 72, 72, 128)  0           conv2d_4[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 72, 72, 128)  147584      dropout_2[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 36, 36, 128)  0           conv2d_5[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 36, 36, 256)  295168      max_pooling2d_2[0][0]
__________________________________________________________________________________________________
dropout_3 (Dropout)             (None, 36, 36, 256)  0           conv2d_6[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 36, 36, 256)  590080      dropout_3[0][0]
__________________________________________________________________________________________________
global_average_pooling2d (Globa (None, 256)          0           conv2d_7[0][0]
__________________________________________________________________________________________________
average_pooling2d_2 (AveragePoo (None, 6, 6, 256)    0           conv2d_7[0][0]
__________________________________________________________________________________________________
average_pooling2d_1 (AveragePoo (None, 4, 4, 256)    0           conv2d_7[0][0]
__________________________________________________________________________________________________
average_pooling2d (AveragePooli (None, 2, 2, 256)    0           conv2d_7[0][0]
__________________________________________________________________________________________________
reshape (Reshape)               (None, 1, 1, 256)    0           global_average_pooling2d[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 6, 6, 64)     16448       average_pooling2d_2[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 4, 4, 64)     16448       average_pooling2d_1[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 2, 2, 64)     16448       average_pooling2d[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 1, 1, 64)     16448       reshape[0][0]
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (None, 36, 36, 64)   16448       conv2d_11[0][0]
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 36, 36, 64)   36928       conv2d_10[0][0]
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 36, 36, 64)   16448       conv2d_9[0][0]
__________________________________________________________________________________________________
conv2d_transpose (Conv2DTranspo (None, 36, 36, 64)   4160        conv2d_8[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 36, 36, 512)  0           conv2d_7[0][0]
                                                                 conv2d_transpose_3[0][0]
                                                                 conv2d_transpose_2[0][0]
                                                                 conv2d_transpose_1[0][0]
                                                                 conv2d_transpose[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 36, 36, 64)   294976      concatenate[0][0]
__________________________________________________________________________________________________
dropout_4 (Dropout)             (None, 36, 36, 64)   0           conv2d_12[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 36, 36, 64)   36928       dropout_4[0][0]
__________________________________________________________________________________________________
conv2d_transpose_4 (Conv2DTrans (None, 288, 288, 64) 16448       conv2d_13[0][0]
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 288, 288, 32) 18464       conv2d_transpose_4[0][0]
__________________________________________________________________________________________________
dropout_5 (Dropout)             (None, 288, 288, 32) 0           conv2d_14[0][0]
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 288, 288, 32) 9248        dropout_5[0][0]
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 288, 288, 3)  99          conv2d_15[0][0]
==================================================================================================
Total params: 1,687,619
Trainable params: 1,687,619
Non-trainable params: 0
```
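To check how the strided Conv2DTranspose layers in the summary above fill their outputs, this NumPy sketch counts the fraction of output pixels each one can write to. It assumes TensorFlow's `padding='same'` behaviour for kernel size at most equal to the stride: the output is input size times stride, and each input pixel only writes to a kernel-sized patch at the start of its stride-sized cell, so every other output pixel receives only the layer's bias. `transposed_conv_coverage` is a hypothetical helper, not a Keras API.

```python
import numpy as np

def transposed_conv_coverage(in_size, kernel, stride):
    """Fraction of output pixels a strided 2-D transposed convolution can write to.

    Assumes padding='same' and kernel <= stride: the output is in_size * stride wide
    and input pixel (i, j) only contributes to the kernel x kernel patch starting at
    (i * stride, j * stride). All other output pixels get nothing but the bias.
    """
    if kernel > stride:
        raise ValueError("this sketch only covers kernel <= stride")
    out_size = in_size * stride
    touched = np.zeros((out_size, out_size), dtype=bool)
    for i in range(in_size):
        for j in range(in_size):
            r, c = i * stride, j * stride
            touched[r:r + kernel, c:c + kernel] = True
    return touched.mean()


# (input size, kernel size, stride) of the Conv2DTranspose layers in the summary
layers = {
    "pyramid_block1": (1, 1, 36),
    "pyramid_block2": (2, 2, 18),
    "pyramid_block3": (4, 3, 9),
    "pyramid_block4": (6, 2, 6),
    "conv8": (36, 2, 8),
}
for name, (in_size, kernel, stride) in layers.items():
    print(f"{name}: {transposed_conv_coverage(in_size, kernel, stride):.4f}")
```

Each value works out to (kernel / stride)^2: 1/1296 for pyramid_block1, 1/81 for pyramid_block2, 1/9 for pyramid_block3 and pyramid_block4, and 1/16 for the final upsampling in conv8. One option that fills every output pixel, closer to the bilinear upsampling described for PSPNet, is `tf.keras.layers.UpSampling2D(size=..., interpolation='bilinear')`.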