Using PSP-Net for image segmentation

Time: 2020-08-10 08:38:07

Tags: python tensorflow keras conv-neural-network image-segmentation

I am building a model for semantic image segmentation with TensorFlow and Keras, and I am trying to implement a PSPNet architecture for it.

I based my architecture mainly on the following sources: https://divamgupta.com/image-segmentation/2019/06/06/deep-learning-semantic-segmentation-keras.html https://developers.arcgis.com/python/guide/how-pspnet-works/

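The dataset I am using is the Oxford cats and dogs image dataset. For each image I have a corresponding mask with the following values:

  • pixel value 0 for background
  • pixel value 1 for pixels showing a cat
  • pixel value 2 for pixels showing a dog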

The images have all been resized to 288x288x1 (1 channel), and I am using the Keras function to_categorical to one-hot encode all mask values.
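For reference, the one-hot step described above can be sketched with plain NumPy. This is a minimal illustration and not the code from the post: the 4x4 toy mask and the helper name one_hot_mask are hypothetical, and np.eye indexing is used as a stand-in that should produce the same array layout as to_categorical for integer labels.

```python
import numpy as np

NUM_CLASSES = 3  # 0 = background, 1 = cat, 2 = dog (as listed in the question)


def one_hot_mask(mask, num_classes=NUM_CLASSES):
    """Turn an (H, W) integer mask into an (H, W, num_classes) float32 array."""
    mask = np.asarray(mask, dtype=np.int64)
    # Illustrative guard: labels outside 0..num_classes-1 would give wrong targets.
    if mask.min() < 0 or mask.max() >= num_classes:
        raise ValueError("mask contains labels outside 0..num_classes-1")
    # Row c of the identity matrix is the one-hot vector for class c.
    return np.eye(num_classes, dtype=np.float32)[mask]


# Hypothetical toy mask standing in for a 288x288 mask.
toy_mask = np.array([[0, 0, 1, 1],
                     [0, 1, 1, 0],
                     [0, 2, 2, 0],
                     [2, 2, 0, 0]])

encoded = one_hot_mask(toy_mask)
print(encoded.shape)   # (4, 4, 3)
print(encoded[2, 1])   # one-hot vector for a "dog" pixel
```

With real 288x288 masks the result would have shape 288x288x3, which is the shape the final softmax layer of the model produces per image.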

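The problem I am running into is this: when I use the resulting model for image segmentation, I get terrible predictions. I think the problem lies in the model structure posted below. I have tried many different changes to the model architecture, but none of them gave better results.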
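To put a number on how bad the predictions are, one option is per-class intersection-over-union (IoU), which is less forgiving than pixel accuracy when background dominates. The sketch below is an assumption about how one might evaluate, not something taken from the post: probs stands in for the softmax output of the model for a single image, and both toy arrays are made up.

```python
import numpy as np


def per_class_iou(true_mask, pred_mask, num_classes=3):
    """Return a list with the IoU of each class (NaN if the class is absent in both masks)."""
    ious = []
    for c in range(num_classes):
        t = true_mask == c
        p = pred_mask == c
        union = np.logical_or(t, p).sum()
        inter = np.logical_and(t, p).sum()
        ious.append(float(inter) / float(union) if union else float("nan"))
    return ious


# Hypothetical ground truth: a small cat region on background.
true_mask = np.array([[0, 0, 1],
                      [0, 1, 1],
                      [0, 0, 0]])

# Hypothetical softmax output of shape (H, W, num_classes) that favours
# background everywhere, as a degenerate model might.
probs = np.zeros((3, 3, 3), dtype=np.float32)
probs[..., 0] = 0.80
probs[..., 1] = 0.15
probs[..., 2] = 0.05

pred_mask = probs.argmax(axis=-1)                 # class index per pixel
accuracy = float((pred_mask == true_mask).mean())
ious = per_class_iou(true_mask, pred_mask)

print(f"pixel accuracy: {accuracy:.3f}")
print("IoU per class (background, cat, dog):", ious)
```

In this toy case the pixel accuracy is about 0.67 although the cat IoU is 0, which is why the 'accuracy' metric alone can hide a model that predicts mostly background.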

My model structure:

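import tensorflow as tf

# Values taken from the model summary below: 288x288 single-channel input, 3 classes
sourceWidth, sourceHeight, sourceChannels = 288, 288, 1
num_classes = 3

# INPUT LAYER #
inputL = tf.keras.layers.Input((sourceWidth, sourceHeight, sourceChannels))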

#inputL = tf.keras.layers.ResNet50()

# ENCODER #
conv1 = tf.keras.layers.Conv2D(32, (3,3), activation='relu', padding='same')(inputL)
conv1 = tf.keras.layers.Dropout(0.1)(conv1)
conv1 = tf.keras.layers.Conv2D(32, (3,3), activation='relu', padding='same')(conv1)
conv1 = tf.keras.layers.MaxPooling2D((2,2))(conv1)

conv2 = tf.keras.layers.Conv2D(64, (3,3), activation='relu', padding='same')(conv1)
conv2 = tf.keras.layers.Dropout(0.1)(conv2)
conv2 = tf.keras.layers.Conv2D(64, (3,3), activation='relu', padding='same')(conv2)
conv2 = tf.keras.layers.MaxPooling2D((2,2))(conv2)

conv3 = tf.keras.layers.Conv2D(128, (3,3), dilation_rate=(2,2), activation='relu', padding='same')(conv2)
conv3 = tf.keras.layers.Dropout(0.1)(conv3)
conv3 = tf.keras.layers.Conv2D(128, (3,3), dilation_rate=(2,2), activation='relu', padding='same')(conv3)
conv3 = tf.keras.layers.MaxPooling2D((2,2))(conv3)

conv4 = tf.keras.layers.Conv2D(256, (3,3), dilation_rate=(4,4), activation='relu', padding='same')(conv3)
conv4 = tf.keras.layers.Dropout(0.1)(conv4)
conv4 = tf.keras.layers.Conv2D(256, (3,3), dilation_rate=(4,4), activation='relu', padding='same')(conv4)


# PYRAMID POOLING MODULE
# Pyramid block 1 #
pyramid_block1 = tf.keras.layers.GlobalAveragePooling2D()(conv4)
pyramid_block1 = tf.keras.layers.Reshape((1,1,256))(pyramid_block1)
pyramid_block1 = tf.keras.layers.Conv2D(64, (1,1), strides=(1,1), activation='relu')(pyramid_block1)
pyramid_block1 = tf.keras.layers.Conv2DTranspose(64, (1,1), strides=(36,36), padding='same')(pyramid_block1)

# Pyramid block 2 #
pyramid_block2 = tf.keras.layers.AveragePooling2D(pool_size=(18,18), strides=(18,18))(conv4)
pyramid_block2 = tf.keras.layers.Conv2D(64, (1,1), strides=(1,1), activation='relu')(pyramid_block2)
pyramid_block2 = tf.keras.layers.Conv2DTranspose(64, (2,2), strides=(18,18), padding='same')(pyramid_block2)

# Pyramid block 3 #
pyramid_block3 = tf.keras.layers.AveragePooling2D(pool_size=(9,9), strides=(9,9))(conv4)
pyramid_block3 = tf.keras.layers.Conv2D(64, (1,1), strides=(1,1), activation='relu')(pyramid_block3)
pyramid_block3 = tf.keras.layers.Conv2DTranspose(64, (3,3), strides=(9,9), padding='same')(pyramid_block3)

# Pyramid block 4 #
pyramid_block4 = tf.keras.layers.AveragePooling2D(pool_size=(6,6), strides=(6,6))(conv4)
pyramid_block4 = tf.keras.layers.Conv2D(64, (1,1), strides=(1,1), activation='relu')(pyramid_block4)
pyramid_block4 = tf.keras.layers.Conv2DTranspose(64, (2,2), strides=(6,6), padding='same')(pyramid_block4)


# DECODER
pyramid = tf.keras.layers.concatenate([conv4, pyramid_block4, pyramid_block3, pyramid_block2, pyramid_block1])
conv7 = tf.keras.layers.Conv2D(64, (3,3), strides=(1,1), activation='relu', padding='same')(pyramid)
conv7 = tf.keras.layers.Dropout(0.2)(conv7)
conv7 = tf.keras.layers.Conv2D(64, (3,3), strides=(1,1), activation='relu', padding='same')(conv7)

conv8 = tf.keras.layers.Conv2DTranspose(64, (2,2), strides=(8,8), padding='same')(conv7)
conv8 = tf.keras.layers.Conv2D(32, (3,3), strides=(1,1), activation='relu', padding='same')(conv8)
conv8 = tf.keras.layers.Dropout(0.2)(conv8)
conv8 = tf.keras.layers.Conv2D(32, (3,3), strides=(1,1), activation='relu', padding='same')(conv8)
output = tf.keras.layers.Conv2D(num_classes, (1,1), activation='softmax')(conv8)

model = tf.keras.Model(inputs=[inputL], outputs=[output])
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
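One property of the Conv2DTranspose layers above that may be worth checking: when the kernel is smaller than the stride, each input pixel writes only a kernel-sized patch into its stride-sized output block, and the remaining output positions receive nothing but the bias. The sketch below illustrates that arithmetic in plain Python. It is not Keras code, the helper names are hypothetical, and whether this actually causes the poor predictions is not established here.

```python
def covered_positions(input_len, kernel, stride):
    """1-D map of output positions that receive a contribution from some input pixel.

    Only valid for kernel <= stride, where the output length is input_len * stride
    and neighbouring patches do not overlap.
    """
    if kernel > stride:
        raise ValueError("this sketch only covers kernel <= stride")
    out = [0] * (input_len * stride)
    for i in range(input_len):
        for k in range(kernel):
            out[i * stride + k] = 1
    return out


def covered_fraction_2d(kernel, stride):
    """Fraction of 2-D output positions touched when kernel <= stride."""
    return (kernel / stride) ** 2


# (kernel, stride) pairs copied from the Conv2DTranspose calls in the model.
layers = {
    "pyramid_block1": (1, 36),
    "pyramid_block2": (2, 18),
    "pyramid_block3": (3, 9),
    "pyramid_block4": (2, 6),
    "conv8": (2, 8),
}

for name, (kernel, stride) in layers.items():
    fraction = covered_fraction_2d(kernel, stride)
    print(f"{name}: kernel={kernel}, stride={stride}, covered={fraction:.2%}")

# Two input pixels, kernel 2, stride 8: only 2 of every 8 outputs are written.
print(covered_positions(2, 2, 8))
```

For conv8 this gives 6.25%, meaning most of the 288x288 feature map entering the last convolutions would be constant. If that turns out to matter, one thing to try is replacing the transposed convolutions with tf.keras.layers.UpSampling2D(size=..., interpolation='bilinear'), or choosing a kernel size at least as large as the stride.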


Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 288, 288, 1)  0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 288, 288, 32) 320         input_1[0][0]                    
__________________________________________________________________________________________________
dropout (Dropout)               (None, 288, 288, 32) 0           conv2d[0][0]                     
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 288, 288, 32) 9248        dropout[0][0]                    
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 144, 144, 32) 0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 144, 144, 64) 18496       max_pooling2d[0][0]              
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 144, 144, 64) 0           conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 144, 144, 64) 36928       dropout_1[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 72, 72, 64)   0           conv2d_3[0][0]                   
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 72, 72, 128)  73856       max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 72, 72, 128)  0           conv2d_4[0][0]                   
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 72, 72, 128)  147584      dropout_2[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 36, 36, 128)  0           conv2d_5[0][0]                   
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 36, 36, 256)  295168      max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
dropout_3 (Dropout)             (None, 36, 36, 256)  0           conv2d_6[0][0]                   
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 36, 36, 256)  590080      dropout_3[0][0]                  
__________________________________________________________________________________________________
global_average_pooling2d (Globa (None, 256)          0           conv2d_7[0][0]                   
__________________________________________________________________________________________________
average_pooling2d_2 (AveragePoo (None, 6, 6, 256)    0           conv2d_7[0][0]                   
__________________________________________________________________________________________________
average_pooling2d_1 (AveragePoo (None, 4, 4, 256)    0           conv2d_7[0][0]                   
__________________________________________________________________________________________________
average_pooling2d (AveragePooli (None, 2, 2, 256)    0           conv2d_7[0][0]                   
__________________________________________________________________________________________________
reshape (Reshape)               (None, 1, 1, 256)    0           global_average_pooling2d[0][0]   
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 6, 6, 64)     16448       average_pooling2d_2[0][0]        
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 4, 4, 64)     16448       average_pooling2d_1[0][0]        
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 2, 2, 64)     16448       average_pooling2d[0][0]          
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 1, 1, 64)     16448       reshape[0][0]                    
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (None, 36, 36, 64)   16448       conv2d_11[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 36, 36, 64)   36928       conv2d_10[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 36, 36, 64)   16448       conv2d_9[0][0]                   
__________________________________________________________________________________________________
conv2d_transpose (Conv2DTranspo (None, 36, 36, 64)   4160        conv2d_8[0][0]                   
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 36, 36, 512)  0           conv2d_7[0][0]                   
                                                                 conv2d_transpose_3[0][0]         
                                                                 conv2d_transpose_2[0][0]         
                                                                 conv2d_transpose_1[0][0]         
                                                                 conv2d_transpose[0][0]           
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 36, 36, 64)   294976      concatenate[0][0]                
__________________________________________________________________________________________________
dropout_4 (Dropout)             (None, 36, 36, 64)   0           conv2d_12[0][0]                  
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 36, 36, 64)   36928       dropout_4[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose_4 (Conv2DTrans (None, 288, 288, 64) 16448       conv2d_13[0][0]                  
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 288, 288, 32) 18464       conv2d_transpose_4[0][0]         
__________________________________________________________________________________________________
dropout_5 (Dropout)             (None, 288, 288, 32) 0           conv2d_14[0][0]                  
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 288, 288, 32) 9248        dropout_5[0][0]                  
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 288, 288, 3)  99          conv2d_15[0][0]                  
==================================================================================================
Total params: 1,687,619
Trainable params: 1,687,619
Non-trainable params: 0
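
The shapes and parameter counts in this summary can be reproduced by hand, which is a quick sanity check that the pyramid branches line up before concatenation. The helper names below are hypothetical; the formulas are the usual ones for pooling without padding, for a transposed convolution with padding='same', and for convolution parameter counts.

```python
def pooled_size(n, pool, stride):
    """Output length of pooling without padding."""
    return (n - pool) // stride + 1


def transposed_same_size(n, stride):
    """Output length of a transposed convolution with padding='same'."""
    return n * stride


def conv_params(kernel, c_in, c_out):
    """Weights plus biases of a square-kernel (transposed) convolution."""
    return kernel * kernel * c_in * c_out + c_out


# Three 2x2 max-pooling layers reduce 288 to 36.
feature = 288 // 2 // 2 // 2
print("encoder output size:", feature)

# pool_size == strides for each AveragePooling2D branch in the model.
branches = {"pyramid_block2": 18, "pyramid_block3": 9, "pyramid_block4": 6}
for name, pool in branches.items():
    bins = pooled_size(feature, pool, pool)
    restored = transposed_same_size(bins, pool)
    print(f"{name}: pooled to {bins}x{bins}, upsampled back to {restored}x{restored}")

# conv4 contributes 256 channels, each of the four branches 64.
concat_channels = 256 + 4 * 64
print("concatenated channels:", concat_channels)

print("conv2d params:", conv_params(3, 1, 32))
print("conv2d_12 params:", conv_params(3, concat_channels, 64))
print("conv2d_16 params:", conv_params(1, 32, 3))
```

The results agree with the table: the branches pool to 2x2, 4x4 and 6x6, all return to 36x36, and the concatenation has 512 channels, so the summary gives no sign of a shape mismatch in the wiring itself.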

0 answers:

No answers yet.