What could be the reasons for poor U-Net performance?

Asked: 2020-04-02 13:19:46

Tags: python-3.x keras image-segmentation

I'm trying to solve a binary segmentation problem with a U-Net on TensorFlow 2.0 (Keras module). My classes are heavily imbalanced, so I have to use class weights (0.03 for the background, 1.0 for the foreground). There are ~2500 samples in the training set and ~250 in the validation set.
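
For context, the background weight is derived from the foreground pixel fraction; a rough sketch, assuming the foreground covers about 3% of all pixels:

foreground_fraction = 0.03                                          # assumed median foreground share
class_1_weight = 1.0                                                # foreground (minority class)
class_0_weight = foreground_fraction / (1 - foreground_fraction)    # background, ~0.03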

Data sample (an image and its mask):

[image and mask screenshots omitted]

The metric is intersection over union (IoU), and the loss function is the Jaccard loss. After about 10 epochs the run is stopped by EarlyStopping: the loss stays very high and the metric very low. I tried lowering the learning rate, but it didn't help much.
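
To be clear about what the metric computes, here is a toy IoU example (toy masks, not my data):

import numpy as np

gt = np.array([[1., 1.], [0., 0.]])
pred = np.array([[1., 0.], [0., 0.]])
intersection = (gt * pred).sum()                 # 1.0
union = gt.sum() + pred.sum() - intersection     # 2.0
print(intersection / union)                      # 0.5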

When I try to predict with the model, it just gives me a black square.

What is wrong with my model? Am I missing something? Is it an architectural flaw? A wrong loss function or metric? A class weighting problem? I'd be grateful for any help.

Network architecture:

from contextlib import redirect_stdout
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (Input, BatchNormalization, Dropout, Conv2D,
                                     Conv2DTranspose, MaxPooling2D, concatenate)
import tensorflow as tf


config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.compat.v1.Session(config=config)
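
# Note: I believe the TF2-native equivalent of the config above would be
# (an assumption on my side, not verified on my setup):
#   for gpu in tf.config.experimental.list_physical_devices('GPU'):
#       tf.config.experimental.set_memory_growth(gpu, True)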
IMAGE_WIDTH = 768


def get_unet(input_image, n_filters, kernel_size, dropout=0.5):
    """Build a 4-level U-Net with BatchNorm and Dropout for binary segmentation."""
    # --- Encoder ---
    conv_1 = Conv2D(filters=n_filters, kernel_size=(kernel_size, kernel_size), data_format="channels_last", activation='relu', kernel_initializer="he_normal", padding="same")(input_image)
    conv_1 = BatchNormalization()(conv_1)
    conv_2 = Conv2D(filters=n_filters, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(conv_1)
    conv_2 = BatchNormalization()(conv_2)
    pool_1 = MaxPooling2D(pool_size=(2, 2))(conv_2)
    pool_1 = Dropout(dropout * 0.5)(pool_1)

    conv_3 = Conv2D(filters=n_filters * 2, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(pool_1)
    conv_3 = BatchNormalization()(conv_3)
    conv_4 = Conv2D(filters=n_filters * 2, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(conv_3)
    conv_4 = BatchNormalization()(conv_4)
    pool_2 = MaxPooling2D(pool_size=(2, 2))(conv_4)
    pool_2 = Dropout(dropout)(pool_2)

    conv_5 = Conv2D(filters=n_filters * 4, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(pool_2)
    conv_5 = BatchNormalization()(conv_5)
    conv_6 = Conv2D(filters=n_filters * 4, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(conv_5)
    conv_6 = BatchNormalization()(conv_6)
    pool_3 = MaxPooling2D(pool_size=(2, 2))(conv_6)
    pool_3 = Dropout(dropout)(pool_3)

    conv_7 = Conv2D(filters=n_filters * 8, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(pool_3)
    conv_7 = BatchNormalization()(conv_7)
    conv_8 = Conv2D(filters=n_filters * 8, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(conv_7)
    conv_8 = BatchNormalization()(conv_8)
    pool_4 = MaxPooling2D(pool_size=(2, 2))(conv_8)
    pool_4 = Dropout(dropout)(pool_4)

    # --- Bottleneck ---
    conv_9 = Conv2D(filters=n_filters * 16, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(pool_4)
    conv_9 = BatchNormalization()(conv_9)
    conv_10 = Conv2D(filters=n_filters * 16, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(conv_9)
    conv_10 = BatchNormalization()(conv_10)

    # --- Decoder ---
    upconv_1 = Conv2DTranspose(n_filters * 8, (kernel_size, kernel_size), strides=(2, 2), padding='same')(conv_10)
    concat_1 = concatenate([upconv_1, conv_8])
    concat_1 = Dropout(dropout)(concat_1)
    conv_11 = Conv2D(filters=n_filters * 8, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(concat_1)
    conv_11 = BatchNormalization()(conv_11)
    conv_12 = Conv2D(filters=n_filters * 8, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(conv_11)
    conv_12 = BatchNormalization()(conv_12)

    upconv_2 = Conv2DTranspose(n_filters * 4, (kernel_size, kernel_size), strides=(2, 2), padding='same')(conv_12)
    concat_2 = concatenate([upconv_2, conv_6])
    concat_2 = Dropout(dropout)(concat_2)
    conv_13 = Conv2D(filters=n_filters * 4, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(concat_2)
    conv_13 = BatchNormalization()(conv_13)
    conv_14 = Conv2D(filters=n_filters * 4, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(conv_13)
    conv_14 = BatchNormalization()(conv_14)

    upconv_3 = Conv2DTranspose(n_filters * 2, (kernel_size, kernel_size), strides=(2, 2), padding='same')(conv_14)
    concat_3 = concatenate([upconv_3, conv_4])
    concat_3 = Dropout(dropout)(concat_3)
    conv_15 = Conv2D(filters=n_filters * 2, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(concat_3)
    conv_15 = BatchNormalization()(conv_15)
    conv_16 = Conv2D(filters=n_filters * 2, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(conv_15)
    conv_16 = BatchNormalization()(conv_16)

    upconv_4 = Conv2DTranspose(n_filters * 1, (kernel_size, kernel_size), strides=(2, 2), padding='same')(conv_16)
    concat_4 = concatenate([upconv_4, conv_2])
    concat_4 = Dropout(dropout)(concat_4)
    conv_17 = Conv2D(filters=n_filters * 1, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(concat_4)
    conv_17 = BatchNormalization()(conv_17)
    conv_18 = Conv2D(filters=n_filters * 1, kernel_size=(kernel_size, kernel_size), activation='relu', kernel_initializer="he_normal", padding="same")(conv_17)
    conv_18 = BatchNormalization()(conv_18)

    # --- Output: per-pixel sigmoid for binary masks ---
    conv_19 = Conv2D(1, (1, 1), activation='sigmoid')(conv_18)
    model = Model(inputs=input_image, outputs=conv_19)
    return model


input_image = Input((IMAGE_WIDTH, IMAGE_WIDTH, 3), name='img')
model = get_unet(input_image, n_filters=16, kernel_size=3, dropout=0.05)

with open('binary_unet_summary.txt', 'w') as f:
    with redirect_stdout(f):
        model.summary()

model_json = model.to_json()
with open("my_basic_unet.json", "w") as json_file:
    json_file.write(model_json)

Data generator and other helper functions:

import os
import random
import statistics

import cv2
import numpy as np


def calc_weights(masks_folder):
    """Calculate class weights from the class distribution in the dataset."""
    class_1_numbers = []
    for file_name in os.listdir(masks_folder):
        mask = cv2.imread(os.path.join(masks_folder, file_name), cv2.IMREAD_GRAYSCALE) / 255.
        class_1_numbers.append(cv2.countNonZero(mask))

    class_1_total = int(statistics.median(class_1_numbers))
    class_0_total = int(IMAGE_WIDTH ** 2 - class_1_total)
    class_1_weight = 1.  # maximum weight for the minority class
    class_0_weight = class_1_total / class_0_total  # proportional weight for the majority class
    return [class_0_weight, class_1_weight]


def data_gen(templates_folder, masks_folder, image_width, batch_size):
    """Generate individual batches form dataset"""
    counter = 0
    images_list = os.listdir(templates_folder)
    random.shuffle(images_list)
    while True:
        templates_pack = np.zeros((batch_size, image_width, image_width, 3)).astype('float')
        masks_pack = np.zeros((batch_size, image_width, image_width, 1)).astype('float')
        for i in range(counter, counter + batch_size):
            template = cv2.imread(templates_folder + '/' + images_list[i]) / 255.
            templates_pack[i - counter] = template

            mask = cv2.imread(masks_folder + '/' + images_list[i], cv2.IMREAD_GRAYSCALE) / 255.
            mask = np.expand_dims(mask, axis=2)  # add a channel axis: (768, 768) -> (768, 768, 1)
            masks_pack[i - counter] = mask

        counter += batch_size
        if counter + batch_size >= len(images_list):
            counter = 0
            random.shuffle(images_list)
        yield templates_pack, masks_pack
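
# A quick shape sanity check I run on the generator (the paths below are
# placeholders for my dataset layout):
#   batch_gen = data_gen("train/templates", "train/masks", IMAGE_WIDTH, batch_size=4)
#   images, masks = next(batch_gen)
#   print(images.shape, masks.shape)  # expected: (4, 768, 768, 3) (4, 768, 768, 1)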


def _gather_channels(x, indexes):
    """Slice tensor along channels axis by given indexes"""
    if tf.keras.backend.image_data_format() == 'channels_last':
        x = tf.keras.backend.permute_dimensions(x, (3, 0, 1, 2))
        x = tf.keras.backend.gather(x, indexes)
        x = tf.keras.backend.permute_dimensions(x, (1, 2, 3, 0))
    else:
        x = tf.keras.backend.permute_dimensions(x, (1, 0, 2, 3))
        x = tf.keras.backend.gather(x, indexes)
        x = tf.keras.backend.permute_dimensions(x, (1, 0, 2, 3))
    return x


def get_reduce_axes(per_image):
    axes = [1, 2] if tf.keras.backend.image_data_format() == 'channels_last' else [2, 3]
    if not per_image:
        axes.insert(0, 0)
    return axes


def gather_channels(*xs, indexes=None):
    """Slice tensors along channels axis by given indexes"""
    if indexes is None:
        return xs
    elif isinstance(indexes, int):
        indexes = [indexes]
    xs = [_gather_channels(x, indexes=indexes) for x in xs]
    return xs


def round_if_needed(x, threshold):
    if threshold is not None:
        x = tf.keras.backend.greater(x, threshold)
        x = tf.keras.backend.cast(x, tf.keras.backend.floatx())
    return x


def average(x, per_image=False, class_weights=None):
    if per_image:
        x = tf.keras.backend.mean(x, axis=0)
    if class_weights is not None:
        x = x * class_weights
    return tf.keras.backend.mean(x)


def jaccard_metric(gt_mask, pred_mask, class_weights=1., class_indexes=None, smooth=1e-5, per_image=False, threshold=None):
    r""" 
    Args:
        gt: ground truth 4D keras tensor (B, H, W, C) or (B, C, H, W)
        pr: prediction 4D keras tensor (B, H, W, C) or (B, C, H, W)
        class_weights: 1. or list of class weights, len(weights) = C
        class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used.
        smooth: value to avoid division by zero
        per_image: if ``True``, metric is calculated as mean over images in batch (B),
            else over whole batch
        threshold: value to round predictions (use ``>`` comparison), if ``None`` prediction will not be round

    Returns:
        IoU/Jaccard score in range [0, 1]
        """
    gt_mask, pred_mask = gather_channels(gt_mask, pred_mask, indexes=class_indexes)
    pred_mask = round_if_needed(pred_mask, threshold)
    axes = get_reduce_axes(per_image)

    # score calculation
    intersection = tf.keras.backend.sum(gt_mask * pred_mask, axis=axes)
    union = tf.keras.backend.sum(gt_mask + pred_mask, axis=axes) - intersection

    score = (intersection + smooth) / (union + smooth)
    score = average(score, per_image, class_weights)

    return score


def jaccard_loss(gt_mask, pred_mask, class_weights=1., class_indexes=None, smooth=1e-5, per_image=False, threshold=None):
    return 1 - jaccard_metric(gt_mask, pred_mask, class_weights=class_weights, class_indexes=class_indexes, smooth=smooth, per_image=per_image, threshold=threshold)


def jaccard_loss_wrapper(class_weights=1., class_indexes=None, smooth=1e-5, per_image=False, threshold=None):
    def jaccard_loss_keras(gt_mask, pred_mask):
        return jaccard_loss(gt_mask, pred_mask, class_weights=class_weights, class_indexes=class_indexes, smooth=smooth, per_image=per_image, threshold=threshold)

    return jaccard_loss_keras


def jaccard_metric_wrapper(class_weights=1., class_indexes=None, smooth=1e-5, per_image=False, threshold=None):
    def jaccard_metric_keras(gt_mask, pred_mask):
        return jaccard_metric(gt_mask, pred_mask, class_weights=class_weights, class_indexes=class_indexes, smooth=smooth, per_image=per_image, threshold=threshold)

    return jaccard_metric_keras
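
A minimal sanity check of the loss on toy tensors (toy values only; a perfect prediction should give near-zero loss):

gt_toy = tf.ones((1, 4, 4, 1))    # toy "all foreground" mask
pred_toy = tf.ones((1, 4, 4, 1))  # identical prediction
loss_fn = jaccard_loss_wrapper(class_weights=1.)
print(loss_fn(gt_toy, pred_toy).numpy())  # ~0.0: perfect overlap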

Model parameters and training:

import os
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.models import model_from_json
from tensorflow.keras.optimizers import Adam
from tensorflow.python.client import device_lib

IMAGE_WIDTH = 768
callbacks = [
    EarlyStopping(patience=5, verbose=1),
    ReduceLROnPlateau(factor=0.1, patience=3, min_lr=0.00001, verbose=1),
    ModelCheckpoint("best_model.h5", verbose=1, save_best_only=True, save_weights_only=False)
]
train_templates_path = "E:/Explorium/images/train/templates"
train_masks_path = "E:/Explorium/images/train/masks"
valid_templates_path = "E:/Explorium/images/valid/templates"
valid_masks_path = "E:/Explorium/images/valid/masks"
TRAIN_SET_SIZE = len(os.listdir(train_templates_path))
VALID_SET_SIZE = len(os.listdir(valid_templates_path))
BATCH_SIZE = 4
EPOCHS = 100
STEPS_PER_EPOCH = TRAIN_SET_SIZE // BATCH_SIZE  # steps must be integers
VALIDATION_STEPS = VALID_SET_SIZE // BATCH_SIZE
train_generator = data_gen(train_templates_path, train_masks_path, IMAGE_WIDTH, batch_size=BATCH_SIZE)
val_generator = data_gen(valid_templates_path, valid_masks_path, IMAGE_WIDTH, batch_size=BATCH_SIZE)

# LOADING ARCHITECTURE AND COMPILING
with open('my_basic_unet.json', 'r') as json_file:
    loaded_model_json = json_file.read()
model = model_from_json(loaded_model_json)

weights = calc_weights(train_masks_path)
loss_function = jaccard_loss_wrapper(class_weights=weights)
metric = jaccard_metric_wrapper(class_weights=weights)

model.compile(optimizer=Adam(learning_rate=0.0001), loss=loss_function, metrics=[metric])

# TRAINING
print("VERSION CHECK:", tf.__version__, tf.test.is_built_with_cuda(), device_lib.list_local_devices(), sep="\n")
history = model.fit_generator(
    train_generator,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_data=val_generator,
    validation_steps=VALIDATION_STEPS,
    callbacks=callbacks,
)
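
And this is roughly how I inspect a prediction afterwards ("sample.png" is a placeholder path; the image must be 768x768x3):

sample = cv2.imread("sample.png") / 255.                 # placeholder path
pred = model.predict(np.expand_dims(sample, axis=0))[0]  # (768, 768, 1)
binary = (pred[..., 0] > 0.5).astype(np.uint8) * 255     # threshold the sigmoid output
cv2.imwrite("prediction.png", binary)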

0 Answers:

No answers yet.