训练 U-Net 模型时,模型没有学习

时间:2019-06-22 11:34:20

标签: keras

我正在尝试训练一个图像分割(segmentation)模型,但损失一直停滞在 0.3370 不再下降,我真的不知道该怎么办,有人能帮忙吗?

这是模型:

def unet(input_shape=(128, 128, 128), optimizer=Adam, initial_learning_rate=5e-4,
         loss_function=weighted_dice_coefficient_loss, n_labels=3,
         activation_name='sigmoid'):
    """Build and compile a 3D U-Net for volumetric segmentation (channels_first).

    Four encoder levels (32-64-128-256 filters) with 2x max pooling, a
    512-filter bottleneck, and a decoder that upsamples with strided
    Conv3DTranspose and concatenates the matching encoder feature map on the
    channel axis. All convolutions use the module-level ``kinit`` initializer.

    :param input_shape: ``(channels, depth, height, width)``. A 3-tuple is
        treated as a single-channel volume and ``1`` is prepended, so the
        default builds a ``(1, 128, 128, 128)`` input (Conv3D requires the
        channel axis).
    :param optimizer: optimizer class, instantiated with ``initial_learning_rate``.
    :param initial_learning_rate: learning rate passed to ``optimizer``.
    :param loss_function: loss passed to ``model.compile``.
    :param n_labels: number of output channels of the final 1x1x1 conv.
    :param activation_name: activation applied to the final logits.
    :return: the compiled ``Model``.
    """
    input_shape = tuple(input_shape)
    if len(input_shape) == 3:
        # channels_first Conv3D needs (channels, depth, height, width) per sample.
        input_shape = (1,) + input_shape

    data_format = 'channels_first'

    def pool(x):
        # Explicit data_format: otherwise MaxPooling3D falls back to the global
        # Keras image_data_format, which may be channels_last.
        return MaxPooling3D(pool_size=(2, 2, 2), data_format=data_format)(x)

    def conv(x, filters):
        # 3x3x3 'same' convolution followed by ReLU.
        return Conv3D(filters, (3, 3, 3), activation='relu', kernel_initializer=kinit,
                      padding='same', data_format=data_format)(x)

    def upsample(x, filters):
        # Learned 2x upsampling.
        return Conv3DTranspose(filters, (2, 2, 2), strides=(2, 2, 2), kernel_initializer=kinit,
                               padding='same', data_format=data_format)(x)

    inputs = Input(shape=input_shape)

    # Encoder.
    conv1 = UnetConv3D(inputs, 32, is_batchnorm=False, name='conv1')
    pool1 = pool(conv1)
    conv2 = UnetConv3D(pool1, 64, is_batchnorm=False, name='conv2')
    pool2 = pool(conv2)
    conv3 = UnetConv3D(pool2, 128, is_batchnorm=False, name='conv3')
    pool3 = pool(conv3)
    conv4 = UnetConv3D(pool3, 256, is_batchnorm=False, name='conv4')
    pool4 = pool(conv4)

    # Bottleneck.
    conv5 = conv(conv(pool4, 512), 512)

    # Decoder: upsample, concatenate skip connection on the channel axis (1).
    up6 = concatenate([upsample(conv5, 256), conv4], axis=1)
    conv6 = conv(conv(up6, 256), 256)

    up7 = concatenate([upsample(conv6, 128), conv3], axis=1)
    conv7 = conv(conv(up7, 128), 128)

    up8 = concatenate([upsample(conv7, 64), conv2], axis=1)
    # NOTE(review): this level has one conv while all others have two;
    # kept as-is -- confirm whether that is intentional.
    conv8 = conv(up8, 64)

    up9 = concatenate([upsample(conv8, 32), conv1], axis=1)
    conv9 = conv(conv(up9, 32), 32)

    # Must stay linear: a ReLU here would clamp the logits to >= 0, so the
    # sigmoid below could never output a value below 0.5.
    logits = Conv3D(n_labels, (1, 1, 1), kernel_initializer=kinit, padding='same',
                    name='final', data_format=data_format)(conv9)
    activation_block = Activation(activation_name)(logits)

    model = Model(inputs=[inputs], outputs=[activation_block])
    try:
        opt = optimizer(learning_rate=initial_learning_rate)
    except TypeError:
        # Keras < 2.3 optimizers only accept the older ``lr`` keyword.
        opt = optimizer(lr=initial_learning_rate)
    model.compile(optimizer=opt, loss=loss_function)
    return model

这是辅助函数:

def UnetConv3D(input, outdim, is_batchnorm, name):
    """Apply two Conv3D(3x3x3, 'same') -> optional BatchNorm -> ReLU stages.

    Layers are named ``<name>_1`` / ``<name>_2``, with ``_bn`` and ``_act``
    suffixes for the normalization and activation layers of each stage.

    :param input: input tensor in channels_first layout.
    :param outdim: number of filters for both convolutions.
    :param is_batchnorm: if true, insert BatchNormalization after each conv.
    :param name: prefix for the layer names.
    :return: output tensor of the second ReLU.
    """
    x = input
    for stage in (1, 2):
        prefix = name + '_%d' % stage
        x = Conv3D(outdim, (3, 3, 3), strides=(1, 1, 1), kernel_initializer=kinit,
                   padding="same", name=prefix, data_format='channels_first')(x)
        if is_batchnorm:
            # channels_first: the channel axis is 1. BatchNormalization defaults
            # to axis=-1, which here would be a spatial axis.
            x = BatchNormalization(axis=1, name=prefix + '_bn')(x)
        x = Activation('relu', name=prefix + '_act')(x)
    return x

这是损失函数:

def weighted_dice_coefficient(y_true, y_pred, axis=(-3, -2, -1), smooth=0.00001):
    """Soft Dice coefficient, averaged over every non-reduced position.

    Sums over ``axis`` to get, for each remaining index,
    ``(2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)``,
    then returns the mean of those scores. The default axis assumes a
    "channels first" layout whose last three axes are spatial; with 5-D inputs
    that presumably means one score per (sample, label) -- confirm with callers.

    :param y_true: ground-truth tensor.
    :param y_pred: predicted tensor, same shape as ``y_true``.
    :param axis: axes to reduce (the spatial axes).
    :param smooth: small constant keeping the ratio defined when both sums are 0.
    :return: scalar mean Dice coefficient.
    """
    overlap = K.sum(y_true * y_pred, axis=axis)
    volume_sum = K.sum(y_true, axis=axis) + K.sum(y_pred, axis=axis) + smooth
    dice_scores = 2. * (overlap + smooth / 2) / volume_sum
    return K.mean(dice_scores)

我的输入体数据大小是 (128, 128, 128)(加上通道维即 (1, 128, 128, 128),见下方模型摘要)。我是不是犯了什么明显的错误?如需更多信息请告诉我。

模型摘要:


Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 1, 128, 128,  0                                            
__________________________________________________________________________________________________
conv1_1 (Conv3D)                (None, 32, 128, 128, 896         input_1[0][0]                    
__________________________________________________________________________________________________
conv1_1_act (Activation)        (None, 32, 128, 128, 0           conv1_1[0][0]                    
__________________________________________________________________________________________________
conv1_2 (Conv3D)                (None, 32, 128, 128, 27680       conv1_1_act[0][0]                
__________________________________________________________________________________________________
conv1_2_act (Activation)        (None, 32, 128, 128, 0           conv1_2[0][0]                    
__________________________________________________________________________________________________
max_pooling3d_1 (MaxPooling3D)  (None, 32, 64, 64, 6 0           conv1_2_act[0][0]                
__________________________________________________________________________________________________
conv2_1 (Conv3D)                (None, 64, 64, 64, 6 55360       max_pooling3d_1[0][0]            
__________________________________________________________________________________________________
conv2_1_act (Activation)        (None, 64, 64, 64, 6 0           conv2_1[0][0]                    
__________________________________________________________________________________________________
conv2_2 (Conv3D)                (None, 64, 64, 64, 6 110656      conv2_1_act[0][0]                
__________________________________________________________________________________________________
conv2_2_act (Activation)        (None, 64, 64, 64, 6 0           conv2_2[0][0]                    
__________________________________________________________________________________________________
max_pooling3d_2 (MaxPooling3D)  (None, 64, 32, 32, 3 0           conv2_2_act[0][0]                
__________________________________________________________________________________________________
conv3_1 (Conv3D)                (None, 128, 32, 32,  221312      max_pooling3d_2[0][0]            
__________________________________________________________________________________________________
conv3_1_act (Activation)        (None, 128, 32, 32,  0           conv3_1[0][0]                    
__________________________________________________________________________________________________
conv3_2 (Conv3D)                (None, 128, 32, 32,  442496      conv3_1_act[0][0]                
__________________________________________________________________________________________________
conv3_2_act (Activation)        (None, 128, 32, 32,  0           conv3_2[0][0]                    
__________________________________________________________________________________________________
max_pooling3d_3 (MaxPooling3D)  (None, 128, 16, 16,  0           conv3_2_act[0][0]                
__________________________________________________________________________________________________
conv4_1 (Conv3D)                (None, 256, 16, 16,  884992      max_pooling3d_3[0][0]            
__________________________________________________________________________________________________
conv4_1_act (Activation)        (None, 256, 16, 16,  0           conv4_1[0][0]                    
__________________________________________________________________________________________________
conv4_2 (Conv3D)                (None, 256, 16, 16,  1769728     conv4_1_act[0][0]                
__________________________________________________________________________________________________
conv4_2_act (Activation)        (None, 256, 16, 16,  0           conv4_2[0][0]                    
__________________________________________________________________________________________________
max_pooling3d_4 (MaxPooling3D)  (None, 256, 8, 8, 8) 0           conv4_2_act[0][0]                
__________________________________________________________________________________________________
conv3d_1 (Conv3D)               (None, 512, 8, 8, 8) 3539456     max_pooling3d_4[0][0]            
__________________________________________________________________________________________________
conv3d_2 (Conv3D)               (None, 512, 8, 8, 8) 7078400     conv3d_1[0][0]                   
__________________________________________________________________________________________________
conv3d_transpose_1 (Conv3DTrans (None, 256, 16, 16,  1048832     conv3d_2[0][0]                   
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 512, 16, 16,  0           conv3d_transpose_1[0][0]         
                                                                 conv4_2_act[0][0]                
__________________________________________________________________________________________________
conv3d_3 (Conv3D)               (None, 256, 16, 16,  3539200     concatenate_1[0][0]              
__________________________________________________________________________________________________
conv3d_4 (Conv3D)               (None, 256, 16, 16,  1769728     conv3d_3[0][0]                   
__________________________________________________________________________________________________
conv3d_transpose_2 (Conv3DTrans (None, 128, 32, 32,  262272      conv3d_4[0][0]                   
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 256, 32, 32,  0           conv3d_transpose_2[0][0]         
                                                                 conv3_2_act[0][0]                
__________________________________________________________________________________________________
conv3d_5 (Conv3D)               (None, 128, 32, 32,  884864      concatenate_2[0][0]              
__________________________________________________________________________________________________
conv3d_6 (Conv3D)               (None, 128, 32, 32,  442496      conv3d_5[0][0]                   
__________________________________________________________________________________________________
conv3d_transpose_3 (Conv3DTrans (None, 64, 64, 64, 6 65600       conv3d_6[0][0]                   
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 128, 64, 64,  0           conv3d_transpose_3[0][0]         
                                                                 conv2_2_act[0][0]                
__________________________________________________________________________________________________
conv3d_7 (Conv3D)               (None, 64, 64, 64, 6 221248      concatenate_3[0][0]              
__________________________________________________________________________________________________
conv3d_transpose_4 (Conv3DTrans (None, 32, 128, 128, 16416       conv3d_7[0][0]                   
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 64, 128, 128, 0           conv3d_transpose_4[0][0]         
                                                                 conv1_2_act[0][0]                
__________________________________________________________________________________________________
conv3d_8 (Conv3D)               (None, 32, 128, 128, 55328       concatenate_4[0][0]              
__________________________________________________________________________________________________
conv3d_9 (Conv3D)               (None, 32, 128, 128, 27680       conv3d_8[0][0]                   
__________________________________________________________________________________________________
final (Conv3D)                  (None, 3, 128, 128,  99          conv3d_9[0][0]                   
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 3, 128, 128,  0           final[0][0]                      
==================================================================================================

提前感谢!

0 个答案:

没有答案