SegNet: Why is training accuracy high but validation accuracy low?

Time: 2019-03-23 08:59:49

Tags: tensorflow machine-learning keras neural-network deep-learning

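I am trying to train SegNet-Basic, implemented in Keras, on the Berkeley DeepDrive dataset. After training, however, validation accuracy plateaus at around 50-60%, while training accuracy rises above 90%.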

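I have tried different optimizers and image sizes. The dataset's masks come as grayscale images; using the provided labels, I split each one into a numpy array of shape [img_height, img_width, num_classes] per image.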
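The mask conversion described above can be sketched roughly as follows. This is only an illustration, not the code used in the question: the function name `mask_to_one_hot` is hypothetical, and it assumes every pixel of the grayscale mask already holds an integer class id in the range 0 to num_classes - 1. If the masks contain an extra "ignore" value, it would have to be remapped or dropped before this step.

```python
import numpy as np

def mask_to_one_hot(mask, num_classes):
    """Turn a 2-D integer label mask into a [H, W, num_classes] one-hot array.

    Hypothetical helper: assumes each pixel value is a class id in
    [0, num_classes - 1].
    """
    mask = np.asarray(mask)
    if mask.ndim != 2:
        raise ValueError("expected a 2-D grayscale mask")
    if mask.min() < 0 or mask.max() >= num_classes:
        raise ValueError("mask contains values outside [0, num_classes - 1]")
    # Row k of the identity matrix is the one-hot vector for class k;
    # indexing with the mask picks one row per pixel.
    return np.eye(num_classes, dtype=np.float32)[mask]
```

A quick sanity check for such a conversion is that `one_hot.argmax(axis=-1)` reproduces the original mask and that every pixel sums to 1 across the class axis.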

My SegNet_Basic implementation is shown below. The _encode and _decode functions implement unpooling with indices, as described in the paper.

from keras.layers import Input, Conv2D, BatchNormalization, Dropout
from keras.models import Model

# _encode and _decode are the pooling / unpooling-with-indices helpers
# mentioned above; their definitions are not shown in this question.

def SegNet_Basic(n_labels, input_shape=(480, 352, 3)):
    filter_size = 64
    inputs = Input(shape=input_shape)

    # Encoder: three conv blocks, each followed by pooling that keeps the indices
    c1 = Conv2D(filter_size, (3, 3), activation='relu', padding='same')(inputs)
    c1 = BatchNormalization()(c1)

    pool_1, mask_1 = _encode(c1)

    c2 = Dropout(0.2)(pool_1)
    c2 = Conv2D(128, (3, 3), activation='relu', padding='same')(c2)
    c2 = BatchNormalization()(c2)

    pool_2, mask_2 = _encode(c2)

    c3 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool_2)
    c3 = BatchNormalization()(c3)

    pool_3, mask_3 = _encode(c3)

    # Bottleneck at 1/8 of the input resolution
    c4 = Dropout(0.2)(pool_3)
    c4 = Conv2D(512, (3, 3), activation='relu', padding='same')(c4)
    c4 = BatchNormalization()(c4)

    c5 = Conv2D(512, (3, 3), activation='relu', padding='same')(c4)
    c5 = BatchNormalization()(c5)

    # Decoder: unpool with the stored indices, in reverse order of the encoder
    c6 = Conv2D(256, (3, 3), activation='relu', padding='same')(c5)
    unpool_1 = _decode(c6, mask_3)

    c6 = Dropout(0.2)(unpool_1)
    c6 = Conv2D(256, (3, 3), activation='relu', padding='same')(c6)
    c6 = BatchNormalization()(c6)

    c7 = Conv2D(128, (3, 3), activation='relu', padding='same')(c6)
    unpool_2 = _decode(c7, mask_2)

    c7 = Dropout(0.2)(unpool_2)
    c7 = Conv2D(128, (3, 3), activation='relu', padding='same')(c7)
    c7 = BatchNormalization()(c7)

    c8 = Conv2D(64, (3, 3), activation='relu', padding='same')(c7)
    unpool_3 = _decode(c8, mask_1)

    c8 = Conv2D(filter_size, (3, 3), activation='relu', padding='same')(unpool_3)
    c8 = BatchNormalization()(c8)

    # Per-pixel class probabilities
    output = Conv2D(n_labels, (1, 1), activation='softmax')(c8)

    # The keyword is `outputs`; `output` is not accepted by current Keras versions
    model = Model(inputs=[inputs], outputs=[output], name='SegNet_Basic')

    return model
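The `_encode` and `_decode` helpers are not included in the question, so their exact behaviour is unknown. Purely to illustrate what pooling and unpooling with indices means, here is a NumPy sketch on a single [H, W, C] array. The function names are hypothetical, it assumes non-overlapping windows and that H and W are divisible by the window size, and it is not a Keras layer.

```python
import numpy as np

def max_pool_with_indices(x, size=2):
    """Non-overlapping max pooling on an [H, W, C] array.

    Returns the pooled array and, for every window and channel, the
    position (0 .. size*size - 1) of the maximum inside that window.
    """
    h, w, c = x.shape
    # Rearrange into [H/size, W/size, size*size, C] so each window is one axis.
    windows = (x.reshape(h // size, size, w // size, size, c)
                .transpose(0, 2, 1, 3, 4)
                .reshape(h // size, w // size, size * size, c))
    return windows.max(axis=2), windows.argmax(axis=2)

def unpool_with_indices(pooled, idx, size=2):
    """Place each pooled value back at its recorded position; zeros elsewhere."""
    ph, pw, c = pooled.shape
    windows = np.zeros((ph, pw, size * size, c), dtype=pooled.dtype)
    np.put_along_axis(windows, idx[:, :, None, :], pooled[:, :, None, :], axis=2)
    # Undo the window rearrangement to get back to [H, W, C].
    return (windows.reshape(ph, pw, size, size, c)
                   .transpose(0, 2, 1, 3, 4)
                   .reshape(ph * size, pw * size, c))
```

In the network above, each `mask_k` returned by the encoder step would play the role of `idx` here: the decoder reuses it to put activations back where the maxima originally were, rather than learning the upsampling.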

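I use an image size of (480, 352) because my laptop's GPU cannot handle anything larger. As the first image I attached shows, classes such as sky and trees are recognized well, but accuracy on the other classes is very poor.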
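Because overall pixel accuracy is dominated by large classes such as sky, a per-class breakdown makes the gap between classes easier to see. The sketch below is one possible way to compute per-class intersection-over-union from integer label maps; the function name `per_class_iou` is hypothetical, and it assumes both inputs are obtained by taking the argmax over the class axis of the one-hot ground truth and the softmax output.

```python
import numpy as np

def per_class_iou(y_true, y_pred, num_classes):
    """Per-class IoU from two integer label maps of the same shape.

    Returns an array of length num_classes; entries are NaN for classes
    that appear in neither the ground truth nor the prediction.
    """
    t = np.asarray(y_true).ravel().astype(np.int64)
    p = np.asarray(y_pred).ravel().astype(np.int64)
    # confusion[i, j] = number of pixels of true class i predicted as class j
    confusion = np.bincount(num_classes * t + p,
                            minlength=num_classes ** 2
                            ).reshape(num_classes, num_classes)
    tp = np.diag(confusion).astype(np.float64)
    union = confusion.sum(axis=0) + confusion.sum(axis=1) - tp
    return np.where(union > 0, tp / np.maximum(union, 1), np.nan)
```

For example, with labels produced by `y.argmax(axis=-1)` on the validation masks and predictions, the returned array shows which classes pull the average down.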

I can provide other parts of the code if needed!

Thanks for your help!

predicted output per class

model accuracy after training

model loss after training

Output of a training picture after training 100 epochs

0 answers:

No answers