UNet: multi-class image segmentation

Time: 2019-12-15 11:57:11

Tags: keras deep-learning image-segmentation

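I recently started learning image segmentation and UNet. I am trying to do multi-class image segmentation with 7 classes. The input is a (256, 256, 3) RGB image and the target is a (256, 256, 1) grayscale image in which each intensity value corresponds to a class. I am applying a pixel-wise softmax, and I am using sparse categorical cross-entropy so that I do not have to one-hot encode the masks.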
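To make the "no one-hot encoding" point concrete, here is a minimal NumPy sketch (not the Keras implementation) showing that cross-entropy computed from integer class ids gives the same value as cross-entropy computed from one-hot targets. The function names, the toy 4x4 size, and the randomly generated probabilities are all assumptions made for illustration.

```python
import numpy as np

def sparse_cross_entropy(labels, probs):
    # labels: (H, W) integer class ids; probs: (H, W, C) per-pixel probabilities.
    # Pick, for every pixel, the probability assigned to its true class.
    picked = np.take_along_axis(probs, labels[..., None], axis=-1)[..., 0]
    return -np.log(picked).mean()

def one_hot_cross_entropy(one_hot, probs):
    # one_hot: (H, W, C) targets; same loss written with one-hot encoding.
    return -(one_hot * np.log(probs)).sum(axis=-1).mean()

rng = np.random.default_rng(0)
n_classes = 7
# Toy 4x4 "image": each pixel holds 7 probabilities that sum to 1.
probs = rng.dirichlet(np.ones(n_classes), size=(4, 4))
# Integer mask with class ids in [0, n_classes).
labels = rng.integers(0, n_classes, size=(4, 4))
one_hot = np.eye(n_classes)[labels]

sparse_loss = sparse_cross_entropy(labels, probs)
dense_loss = one_hot_cross_entropy(one_hot, probs)
print(np.isclose(sparse_loss, dense_loss))
```

Note that the sparse form indexes the class axis with the label value, so the mask has to contain class ids from 0 to 6. If the grayscale mask stores other intensity values, they would need to be remapped to that range first.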

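# Imports assumed here; the original post did not show them.
import keras
from keras.layers import (Activation, BatchNormalization, Conv2D, Conv2DTranspose,
                          Dropout, MaxPooling2D, Reshape, concatenate)
from keras.models import Model

def soft1(x):
    # Softmax over the last (class) axis.
    return keras.activations.softmax(x, axis = -1)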

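def conv2d_block(input_tensor, n_filters, kernel_size = 3, batchnorm = True):
    # First conv layer
    x = Conv2D(filters = n_filters, kernel_size = (kernel_size, kernel_size),
               kernel_initializer = 'he_normal', padding = 'same')(input_tensor)
    if batchnorm:
        x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Second conv layer
    # NOTE: this layer is applied to input_tensor rather than x, so the output of
    # the first layer is discarded. The model summary accordingly shows only one
    # convolution per block. Pass x here to chain the two convolutions.
    x = Conv2D(filters = n_filters, kernel_size = (kernel_size, kernel_size),
               kernel_initializer = 'he_normal', padding = 'same')(input_tensor)
    if batchnorm:
        x = BatchNormalization()(x)
    x = Activation('relu')(x)

    return x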

def get_unet(input_img, n_classes, n_filters = 16, dropout = 0.1, batchnorm = True):
    # Contracting Path
    c1 = conv2d_block(input_img, n_filters * 1, kernel_size = 3, batchnorm = batchnorm)
    p1 = MaxPooling2D((2, 2))(c1)
    p1 = Dropout(dropout)(p1)

    c2 = conv2d_block(p1, n_filters * 2, kernel_size = 3, batchnorm = batchnorm)
    p2 = MaxPooling2D((2, 2))(c2)
    p2 = Dropout(dropout)(p2)

    c3 = conv2d_block(p2, n_filters * 4, kernel_size = 3, batchnorm = batchnorm)
    p3 = MaxPooling2D((2, 2))(c3)
    p3 = Dropout(dropout)(p3)

    c4 = conv2d_block(p3, n_filters * 8, kernel_size = 3, batchnorm = batchnorm)
    p4 = MaxPooling2D((2, 2))(c4)
    p4 = Dropout(dropout)(p4)

    c5 = conv2d_block(p4, n_filters = n_filters * 16, kernel_size = 3, batchnorm = batchnorm)

    # Expansive Path
    u6 = Conv2DTranspose(n_filters * 8, (3, 3), strides = (2, 2), padding = 'same')(c5)
    u6 = concatenate([u6, c4])
    u6 = Dropout(dropout)(u6)
    c6 = conv2d_block(u6, n_filters * 8, kernel_size = 3, batchnorm = batchnorm)

    u7 = Conv2DTranspose(n_filters * 4, (3, 3), strides = (2, 2), padding = 'same')(c6)
    u7 = concatenate([u7, c3])
    u7 = Dropout(dropout)(u7)
    c7 = conv2d_block(u7, n_filters * 4, kernel_size = 3, batchnorm = batchnorm)

    u8 = Conv2DTranspose(n_filters * 2, (3, 3), strides = (2, 2), padding = 'same')(c7)
    u8 = concatenate([u8, c2])
    u8 = Dropout(dropout)(u8)
    c8 = conv2d_block(u8, n_filters * 2, kernel_size = 3, batchnorm = batchnorm)

    u9 = Conv2DTranspose(n_filters * 1, (3, 3), strides = (2, 2), padding = 'same')(c8)
    u9 = concatenate([u9, c1])
    u9 = Dropout(dropout)(u9)
    c9 = conv2d_block(u9, n_filters * 1, kernel_size = 3, batchnorm = batchnorm)

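    # 1x1 convolution: one logit per class for every pixel.
    outputs = Conv2D(n_classes, (1, 1))(c9)
    # image_height and image_width are not defined in this snippet; the summary
    # (65536 = 256 * 256) implies both are 256.
    outputs = Reshape((image_height*image_width, 1, n_classes), input_shape = (image_height, image_width, n_classes))(outputs)
    outputs = Activation(soft1)(outputs)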

    model = Model(inputs=[input_img], outputs=[outputs])
    print(outputs.shape)

    return model

My model summary is:

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_12 (InputLayer)           (None, 256, 256, 3)  0                                            
__________________________________________________________________________________________________
conv2d_211 (Conv2D)             (None, 256, 256, 16) 448         input_12[0][0]                   
__________________________________________________________________________________________________
batch_normalization_200 (BatchN (None, 256, 256, 16) 64          conv2d_211[0][0]                 
__________________________________________________________________________________________________
activation_204 (Activation)     (None, 256, 256, 16) 0           batch_normalization_200[0][0]    
__________________________________________________________________________________________________
max_pooling2d_45 (MaxPooling2D) (None, 128, 128, 16) 0           activation_204[0][0]             
__________________________________________________________________________________________________
dropout_89 (Dropout)            (None, 128, 128, 16) 0           max_pooling2d_45[0][0]           
__________________________________________________________________________________________________
conv2d_213 (Conv2D)             (None, 128, 128, 32) 4640        dropout_89[0][0]                 
__________________________________________________________________________________________________
batch_normalization_202 (BatchN (None, 128, 128, 32) 128         conv2d_213[0][0]                 
__________________________________________________________________________________________________
activation_206 (Activation)     (None, 128, 128, 32) 0           batch_normalization_202[0][0]    
__________________________________________________________________________________________________
max_pooling2d_46 (MaxPooling2D) (None, 64, 64, 32)   0           activation_206[0][0]             
__________________________________________________________________________________________________
dropout_90 (Dropout)            (None, 64, 64, 32)   0           max_pooling2d_46[0][0]           
__________________________________________________________________________________________________
conv2d_215 (Conv2D)             (None, 64, 64, 64)   18496       dropout_90[0][0]                 
__________________________________________________________________________________________________
batch_normalization_204 (BatchN (None, 64, 64, 64)   256         conv2d_215[0][0]                 
__________________________________________________________________________________________________
activation_208 (Activation)     (None, 64, 64, 64)   0           batch_normalization_204[0][0]    
__________________________________________________________________________________________________
max_pooling2d_47 (MaxPooling2D) (None, 32, 32, 64)   0           activation_208[0][0]             
__________________________________________________________________________________________________
dropout_91 (Dropout)            (None, 32, 32, 64)   0           max_pooling2d_47[0][0]           
__________________________________________________________________________________________________
conv2d_217 (Conv2D)             (None, 32, 32, 128)  73856       dropout_91[0][0]                 
__________________________________________________________________________________________________
batch_normalization_206 (BatchN (None, 32, 32, 128)  512         conv2d_217[0][0]                 
__________________________________________________________________________________________________
activation_210 (Activation)     (None, 32, 32, 128)  0           batch_normalization_206[0][0]    
__________________________________________________________________________________________________
max_pooling2d_48 (MaxPooling2D) (None, 16, 16, 128)  0           activation_210[0][0]             
__________________________________________________________________________________________________
dropout_92 (Dropout)            (None, 16, 16, 128)  0           max_pooling2d_48[0][0]           
__________________________________________________________________________________________________
conv2d_219 (Conv2D)             (None, 16, 16, 256)  295168      dropout_92[0][0]                 
__________________________________________________________________________________________________
batch_normalization_208 (BatchN (None, 16, 16, 256)  1024        conv2d_219[0][0]                 
__________________________________________________________________________________________________
activation_212 (Activation)     (None, 16, 16, 256)  0           batch_normalization_208[0][0]    
__________________________________________________________________________________________________
conv2d_transpose_45 (Conv2DTran (None, 32, 32, 128)  295040      activation_212[0][0]             
__________________________________________________________________________________________________
concatenate_45 (Concatenate)    (None, 32, 32, 256)  0           conv2d_transpose_45[0][0]        
                                                                 activation_210[0][0]             
__________________________________________________________________________________________________
dropout_93 (Dropout)            (None, 32, 32, 256)  0           concatenate_45[0][0]             
__________________________________________________________________________________________________
conv2d_221 (Conv2D)             (None, 32, 32, 128)  295040      dropout_93[0][0]                 
__________________________________________________________________________________________________
batch_normalization_210 (BatchN (None, 32, 32, 128)  512         conv2d_221[0][0]                 
__________________________________________________________________________________________________
activation_214 (Activation)     (None, 32, 32, 128)  0           batch_normalization_210[0][0]    
__________________________________________________________________________________________________
conv2d_transpose_46 (Conv2DTran (None, 64, 64, 64)   73792       activation_214[0][0]             
__________________________________________________________________________________________________
concatenate_46 (Concatenate)    (None, 64, 64, 128)  0           conv2d_transpose_46[0][0]        
                                                                 activation_208[0][0]             
__________________________________________________________________________________________________
dropout_94 (Dropout)            (None, 64, 64, 128)  0           concatenate_46[0][0]             
__________________________________________________________________________________________________
conv2d_223 (Conv2D)             (None, 64, 64, 64)   73792       dropout_94[0][0]                 
__________________________________________________________________________________________________
batch_normalization_212 (BatchN (None, 64, 64, 64)   256         conv2d_223[0][0]                 
__________________________________________________________________________________________________
activation_216 (Activation)     (None, 64, 64, 64)   0           batch_normalization_212[0][0]    
__________________________________________________________________________________________________
conv2d_transpose_47 (Conv2DTran (None, 128, 128, 32) 18464       activation_216[0][0]             
__________________________________________________________________________________________________
concatenate_47 (Concatenate)    (None, 128, 128, 64) 0           conv2d_transpose_47[0][0]        
                                                                 activation_206[0][0]             
__________________________________________________________________________________________________
dropout_95 (Dropout)            (None, 128, 128, 64) 0           concatenate_47[0][0]             
__________________________________________________________________________________________________
conv2d_225 (Conv2D)             (None, 128, 128, 32) 18464       dropout_95[0][0]                 
__________________________________________________________________________________________________
batch_normalization_214 (BatchN (None, 128, 128, 32) 128         conv2d_225[0][0]                 
__________________________________________________________________________________________________
activation_218 (Activation)     (None, 128, 128, 32) 0           batch_normalization_214[0][0]    
__________________________________________________________________________________________________
conv2d_transpose_48 (Conv2DTran (None, 256, 256, 16) 4624        activation_218[0][0]             
__________________________________________________________________________________________________
concatenate_48 (Concatenate)    (None, 256, 256, 32) 0           conv2d_transpose_48[0][0]        
                                                                 activation_204[0][0]             
__________________________________________________________________________________________________
dropout_96 (Dropout)            (None, 256, 256, 32) 0           concatenate_48[0][0]             
__________________________________________________________________________________________________
conv2d_227 (Conv2D)             (None, 256, 256, 16) 4624        dropout_96[0][0]                 
__________________________________________________________________________________________________
batch_normalization_216 (BatchN (None, 256, 256, 16) 64          conv2d_227[0][0]                 
__________________________________________________________________________________________________
activation_220 (Activation)     (None, 256, 256, 16) 0           batch_normalization_216[0][0]    
__________________________________________________________________________________________________
conv2d_228 (Conv2D)             (None, 256, 256, 7)  119         activation_220[0][0]             
__________________________________________________________________________________________________
reshape_12 (Reshape)            (None, 65536, 1, 7)  0           conv2d_228[0][0]                 
__________________________________________________________________________________________________
activation_221 (Activation)     (None, 65536, 1, 7)  0           reshape_12[0][0]                 
==================================================================================================
Total params: 1,179,511
Trainable params: 1,178,039
Non-trainable params: 1,472
__________________________________________________________________________________________________

Is my model correct? Since I am using softmax, shouldn't the final output be (65536, 1, 1)? The code compiles, but the Dice coefficient is very low.
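The Dice metric used in the question is not shown, so the following is only a sketch of one common way to compute a per-class Dice score from two integer label maps (for example, the ground-truth mask and the arg max of the softmax output). The function name `dice_per_class`, the `eps` smoothing term, and the toy arrays are assumptions for illustration.

```python
import numpy as np

def dice_per_class(y_true, y_pred, n_classes, eps=1e-7):
    # y_true, y_pred: integer label maps of identical shape.
    scores = []
    for c in range(n_classes):
        true_c = (y_true == c)
        pred_c = (y_pred == c)
        intersection = np.logical_and(true_c, pred_c).sum()
        denom = true_c.sum() + pred_c.sum()
        # eps keeps the score defined when a class is absent from both maps.
        scores.append((2.0 * intersection + eps) / (denom + eps))
    return np.array(scores)

# Toy 2x3 example with 3 classes.
y_true = np.array([[0, 0, 1],
                   [1, 2, 2]])
y_pred = np.array([[0, 1, 1],
                   [1, 2, 0]])
scores = dice_per_class(y_true, y_pred, n_classes=3)
print(scores)
```

A Dice function written for binary masks, applied directly to integer labels and a 7-channel softmax output, would not measure per-class overlap. Whether that is what happens here cannot be told from the question, since the metric code is not included.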

2 answers:

Answer 0 (score: 1)

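Your output actually represents the pixels of the image: for each pixel, the output is a 1x7 vector. Because it passes through a softmax, each of those 7 values lies between 0 and 1, and the 7 values sum to 1. The entry for the predicted class is the one that fires, and that is what produces the segmentation. If the output were (65536, 1, 1), it would not be a categorical output but a dense, single-value representation.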
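As a small NumPy illustration of that point (a sketch with made-up random logits on a toy 4x4 grid, not the model's real output): a softmax over the class axis yields 7 probabilities per pixel, the arg max over that axis recovers a single class id per pixel, and a softmax over an axis of size 1 would always be exactly 1.

```python
import numpy as np

def softmax(logits, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
h, w, n_classes = 4, 4, 7
logits = rng.normal(size=(h, w, n_classes))

# Each pixel gets 7 values in [0, 1] that sum to 1.
probs = softmax(logits)

# Arg max over the class axis collapses (H, W, 7) to an (H, W) label map.
label_map = probs.argmax(axis=-1)

# With a single channel, softmax has nothing to normalize against:
# every output is 1 regardless of the input.
single = softmax(logits[..., :1])

print(probs.shape, label_map.shape, single.shape)
```

This is why the last axis needs one entry per class: the class information lives in how the probability mass is distributed along that axis.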

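Answer 1 (score: 1)

Your model should end with an output of shape (256, 256, 7).

That is, 7 classes per pixel, which lines up with your (256, 256, 1) target images. This only works with 'sparse_categorical_crossentropy' or a custom loss.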

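So your model looks fine up to conv2d_228 (I did not check it in detail, though).
You don't need anything after that convolution.

You can put the softmax directly in conv2d_228 or right after it.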

y_train should be (256, 256, 1).
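A minimal sketch of the output head this answer describes, assuming TensorFlow's bundled Keras is available. The `features` input is a hypothetical stand-in for the last decoder tensor (`c9` in the question), and the optimizer choice and random data are placeholders for illustration only.

```python
import numpy as np
from tensorflow import keras

n_classes = 7

# Hypothetical stand-in for c9: a (256, 256, 16) feature map from the decoder.
features = keras.Input(shape=(256, 256, 16))

# 1x1 convolution with softmax applied over the last (class) axis.
# No Reshape is needed: the output stays (256, 256, 7).
outputs = keras.layers.Conv2D(n_classes, (1, 1), activation="softmax")(features)

head = keras.Model(features, outputs)
head.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

# Random placeholder data: integer masks of shape (256, 256, 1) with ids 0..6.
x = np.random.rand(2, 256, 256, 16).astype("float32")
y = np.random.randint(0, n_classes, size=(2, 256, 256, 1))

loss = head.evaluate(x, y, verbose=0)
print(head.output_shape, loss)
```

Keeping the spatial layout intact means the predictions and the masks align pixel for pixel, and the arg max over the last axis of a prediction gives a (256, 256) label map directly.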