我正在尝试训练细分模型,但是损失饱和在0.3370,我真的不确定该怎么办,有人可以帮忙
这是模型
def unet(input_shape=(128, 128, 128), optimizer=Adam, initial_learning_rate=5e-4,
loss_function=weighted_dice_coefficient_loss):
inputs = Input(shape=input_shape)
conv1 = UnetConv3D(inputs, 32, is_batchnorm=False, name='conv1')
pool1 = MaxPooling3D(pool_size=(2, 2,2 ))(conv1)
conv2 = UnetConv3D(pool1, 64, is_batchnorm=False, name='conv2')
pool2 = MaxPooling3D(pool_size=(2, 2,2 ))(conv2)
conv3 = UnetConv3D(pool2, 128, is_batchnorm=False, name='conv3')
pool3 = MaxPooling3D(pool_size=(2, 2,2 ))(conv3)
conv4 = UnetConv3D(pool3, 256, is_batchnorm=False, name='conv4')
pool4 = MaxPooling3D(pool_size=(2, 2,2 ))(conv4)
conv5 = Conv3D(512, (3, 3, 3), activation='relu', kernel_initializer=kinit, padding='same', data_format = 'channels_first')(pool4)
conv5 = Conv3D(512, (3, 3, 3), activation='relu', kernel_initializer=kinit, padding='same', data_format = 'channels_first')(conv5)
up6 = concatenate([Conv3DTranspose(256, (2, 2,2 ), strides=(2, 2,2 ), kernel_initializer=kinit, padding='same', data_format = 'channels_first')(conv5), conv4], axis=1)
conv6 = Conv3D(256, (3, 3, 3), activation='relu', padding='same', data_format = 'channels_first')(up6)
conv6 = Conv3D(256, (3, 3, 3), activation='relu', padding='same', data_format = 'channels_first')(conv6)
up7 = concatenate([Conv3DTranspose(128, (2, 2,2 ), strides=(2, 2,2 ), padding='same', data_format = 'channels_first')(conv6), conv3], axis=1)
conv7 = Conv3D(128, (3, 3, 3), activation='relu', kernel_initializer=kinit, padding='same', data_format = 'channels_first')(up7)
conv7 = Conv3D(128, (3, 3, 3), activation='relu', kernel_initializer=kinit, padding='same', data_format = 'channels_first')(conv7)
up8 = concatenate([Conv3DTranspose(64, (2, 2,2 ), strides=(2,2,2 ), kernel_initializer=kinit, padding='same', data_format = 'channels_first')(conv7), conv2], axis=1)
conv8 = Conv3D(64, (3, 3, 3), activation='relu', kernel_initializer=kinit, padding='same', data_format = 'channels_first')(up8)
up9 = concatenate([Conv3DTranspose(32, (2, 2,2 ), strides=(2, 2,2 ), kernel_initializer=kinit, padding='same', data_format = 'channels_first')(conv8), conv1], axis=1)
conv9 = Conv3D(32, (3, 3, 3), activation='relu', kernel_initializer=kinit, padding='same', data_format = 'channels_first')(up9)
conv9 = Conv3D(32, (3, 3, 3), activation='relu', kernel_initializer=kinit, padding='same', data_format = 'channels_first')(conv9)
conv10 = Conv3D(3, (1, 1, 1), activation='relu', kernel_initializer=kinit,padding = 'same', name='final', data_format = 'channels_first')(conv9)
activation_name = 'sigmoid'
activation_block = Activation(activation_name)(conv10)
model = Model(inputs=[inputs], outputs=[activation_block])
model.compile(optimizer=optimizer(), loss=loss_function)
return model
这是辅助功能
def UnetConv3D(input, outdim, is_batchnorm, name):
x = Conv3D(outdim, (3, 3, 3), strides=(1, 1, 1), kernel_initializer=kinit, padding="same", name=name+'_1', data_format = 'channels_first')(input)
if is_batchnorm:
x =BatchNormalization(name=name + '_1_bn')(x)
x = Activation('relu',name=name + '_1_act')(x)
x = Conv3D(outdim, (3, 3, 3), strides=(1, 1, 1), kernel_initializer=kinit, padding="same", name=name+'_2', data_format = 'channels_first')(x)
if is_batchnorm:
x = BatchNormalization(name=name + '_2_bn')(x)
x = Activation('relu', name=name + '_2_act')(x)
return x
这是损失函数-
def weighted_dice_coefficient(y_true, y_pred, axis=(-3, -2, -1), smooth=0.00001):
"""
Weighted dice coefficient. Default axis assumes a "channels first" data structure
:param smooth:
:param y_true:
:param y_pred:
:param axis:
:return:
"""
return K.mean(2. * (K.sum(y_true * y_pred,
axis=axis) + smooth/2)/(K.sum(y_true,
axis=axis) + K.sum(y_pred,
axis=axis) + smooth))
我的输入是(128,128,128),我犯了一个明显的错误吗?请让我知道是否需要更多信息。
模型摘要
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 1, 128, 128, 0
__________________________________________________________________________________________________
conv1_1 (Conv3D) (None, 32, 128, 128, 896 input_1[0][0]
__________________________________________________________________________________________________
conv1_1_act (Activation) (None, 32, 128, 128, 0 conv1_1[0][0]
__________________________________________________________________________________________________
conv1_2 (Conv3D) (None, 32, 128, 128, 27680 conv1_1_act[0][0]
__________________________________________________________________________________________________
conv1_2_act (Activation) (None, 32, 128, 128, 0 conv1_2[0][0]
__________________________________________________________________________________________________
max_pooling3d_1 (MaxPooling3D) (None, 32, 64, 64, 6 0 conv1_2_act[0][0]
__________________________________________________________________________________________________
conv2_1 (Conv3D) (None, 64, 64, 64, 6 55360 max_pooling3d_1[0][0]
__________________________________________________________________________________________________
conv2_1_act (Activation) (None, 64, 64, 64, 6 0 conv2_1[0][0]
__________________________________________________________________________________________________
conv2_2 (Conv3D) (None, 64, 64, 64, 6 110656 conv2_1_act[0][0]
__________________________________________________________________________________________________
conv2_2_act (Activation) (None, 64, 64, 64, 6 0 conv2_2[0][0]
__________________________________________________________________________________________________
max_pooling3d_2 (MaxPooling3D) (None, 64, 32, 32, 3 0 conv2_2_act[0][0]
__________________________________________________________________________________________________
conv3_1 (Conv3D) (None, 128, 32, 32, 221312 max_pooling3d_2[0][0]
__________________________________________________________________________________________________
conv3_1_act (Activation) (None, 128, 32, 32, 0 conv3_1[0][0]
__________________________________________________________________________________________________
conv3_2 (Conv3D) (None, 128, 32, 32, 442496 conv3_1_act[0][0]
__________________________________________________________________________________________________
conv3_2_act (Activation) (None, 128, 32, 32, 0 conv3_2[0][0]
__________________________________________________________________________________________________
max_pooling3d_3 (MaxPooling3D) (None, 128, 16, 16, 0 conv3_2_act[0][0]
__________________________________________________________________________________________________
conv4_1 (Conv3D) (None, 256, 16, 16, 884992 max_pooling3d_3[0][0]
__________________________________________________________________________________________________
conv4_1_act (Activation) (None, 256, 16, 16, 0 conv4_1[0][0]
__________________________________________________________________________________________________
conv4_2 (Conv3D) (None, 256, 16, 16, 1769728 conv4_1_act[0][0]
__________________________________________________________________________________________________
conv4_2_act (Activation) (None, 256, 16, 16, 0 conv4_2[0][0]
__________________________________________________________________________________________________
max_pooling3d_4 (MaxPooling3D) (None, 256, 8, 8, 8) 0 conv4_2_act[0][0]
__________________________________________________________________________________________________
conv3d_1 (Conv3D) (None, 512, 8, 8, 8) 3539456 max_pooling3d_4[0][0]
__________________________________________________________________________________________________
conv3d_2 (Conv3D) (None, 512, 8, 8, 8) 7078400 conv3d_1[0][0]
__________________________________________________________________________________________________
conv3d_transpose_1 (Conv3DTrans (None, 256, 16, 16, 1048832 conv3d_2[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 512, 16, 16, 0 conv3d_transpose_1[0][0]
conv4_2_act[0][0]
__________________________________________________________________________________________________
conv3d_3 (Conv3D) (None, 256, 16, 16, 3539200 concatenate_1[0][0]
__________________________________________________________________________________________________
conv3d_4 (Conv3D) (None, 256, 16, 16, 1769728 conv3d_3[0][0]
__________________________________________________________________________________________________
conv3d_transpose_2 (Conv3DTrans (None, 128, 32, 32, 262272 conv3d_4[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, 256, 32, 32, 0 conv3d_transpose_2[0][0]
conv3_2_act[0][0]
__________________________________________________________________________________________________
conv3d_5 (Conv3D) (None, 128, 32, 32, 884864 concatenate_2[0][0]
__________________________________________________________________________________________________
conv3d_6 (Conv3D) (None, 128, 32, 32, 442496 conv3d_5[0][0]
__________________________________________________________________________________________________
conv3d_transpose_3 (Conv3DTrans (None, 64, 64, 64, 6 65600 conv3d_6[0][0]
__________________________________________________________________________________________________
concatenate_3 (Concatenate) (None, 128, 64, 64, 0 conv3d_transpose_3[0][0]
conv2_2_act[0][0]
__________________________________________________________________________________________________
conv3d_7 (Conv3D) (None, 64, 64, 64, 6 221248 concatenate_3[0][0]
__________________________________________________________________________________________________
conv3d_transpose_4 (Conv3DTrans (None, 32, 128, 128, 16416 conv3d_7[0][0]
__________________________________________________________________________________________________
concatenate_4 (Concatenate) (None, 64, 128, 128, 0 conv3d_transpose_4[0][0]
conv1_2_act[0][0]
__________________________________________________________________________________________________
conv3d_8 (Conv3D) (None, 32, 128, 128, 55328 concatenate_4[0][0]
__________________________________________________________________________________________________
conv3d_9 (Conv3D) (None, 32, 128, 128, 27680 conv3d_8[0][0]
__________________________________________________________________________________________________
final (Conv3D) (None, 3, 128, 128, 99 conv3d_9[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 3, 128, 128, 0 final[0][0]
==================================================================================================
预先感谢