Gender classification with a VGG model

Date: 2018-11-11 10:34:49

Tags: keras neural-network deep-learning computer-vision conv-neural-network

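I am using the code below to classify gender (male vs. female). However, it overfits, and the validation accuracy does not even reach 90%. I would appreciate your suggestions.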

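import math

import numpy as np
from keras import applications
from keras.layers import Dense, Dropout, Flatten
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator

# input image size fed to VGG16
img_width, img_height = 128, 128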

top_model_weights_path = 'bottleneck_fc_model.h5'
train_data_dir = 'Train'
validation_data_dir = 'Test'
nb_train_samples = 30000
nb_validation_samples = 7000
epochs = 150
batch_size = 128


def save_bottlebeck_features():
    datagen = ImageDataGenerator(rescale=1. / 255)

    # build the VGG16 network
    model = applications.VGG16(include_top=False, weights='imagenet')

    generator = datagen.flow_from_directory(
        train_data_dir,
        target_size=(img_width, img_height),
        batch_size=batch_size,
        class_mode=None,
        shuffle=False)
    predict_size_train = int(math.ceil(nb_train_samples / batch_size))
    bottleneck_features_train = model.predict_generator(generator, predict_size_train)
    np.save('bottleneck_features_train.npy',
            bottleneck_features_train)

    generator = datagen.flow_from_directory(
        validation_data_dir,
        target_size=(img_width, img_height),
        batch_size=batch_size,
        class_mode=None,
        shuffle=False)
    predict_size_validation = int(math.ceil(nb_validation_samples / batch_size))
    bottleneck_features_validation = model.predict_generator(generator, predict_size_validation)
    np.save('bottleneck_features_validation.npy',
            bottleneck_features_validation)


def train_top_model():
    train_data = np.load('bottleneck_features_train.npy')
    train_labels = np.array(
        [0] * (nb_train_samples // 2) + [1] * (nb_train_samples // 2))

    validation_data = np.load('bottleneck_features_validation.npy')
    validation_labels = np.array(
        [0] * (nb_validation_samples // 2) + [1] * (nb_validation_samples // 2))

    model = Sequential()
    model.add(Flatten(input_shape=train_data.shape[1:]))
    model.add(Dense(512, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(optimizer='rmsprop',
                  loss='binary_crossentropy', metrics=['accuracy'])

    model.fit(train_data, train_labels,
              epochs=epochs,
              batch_size=batch_size,
              validation_data=(validation_data, validation_labels))
    model.save_weights(top_model_weights_path)


save_bottlebeck_features()
train_top_model()

Here are the last epochs:

Epoch 130/150
loss: 0.0337 - acc: 0.9902 - val_loss: 1.1683 - val_acc: 0.8356
Epoch 131/150
loss: 0.0307 - acc: 0.9919 - val_loss: 1.0721 - val_acc: 0.8345
Epoch 132/150
loss: 0.0313 - acc: 0.9914 - val_loss: 1.1606 - val_acc: 0.8342
Epoch 133/150
loss: 0.0316 - acc: 0.9914 - val_loss: 1.1487 - val_acc: 0.8347
Epoch 134/150
loss: 0.0311 - acc: 0.9909 - val_loss: 1.1363 - val_acc: 0.8356
Epoch 135/150
loss: 0.0295 - acc: 0.9914 - val_loss: 1.2289 - val_acc: 0.8355
Epoch 136/150
loss: 0.0325 - acc: 0.9912 - val_loss: 1.1787 - val_acc: 0.8345
Epoch 137/150
loss: 0.0276 - acc: 0.9922 - val_loss: 1.2281 - val_acc: 0.8337
Epoch 138/150
loss: 0.0314 - acc: 0.9918 - val_loss: 1.1973 - val_acc: 0.8352
Epoch 139/150
loss: 0.0298 - acc: 0.9913 - val_loss: 1.1551 - val_acc: 0.8311
Epoch 140/150
loss: 0.0301 - acc: 0.9919 - val_loss: 1.2301 - val_acc: 0.8339
Epoch 141/150
loss: 0.0315 - acc: 0.9917 - val_loss: 1.1344 - val_acc: 0.8328
Epoch 142/150
loss: 0.0290 - acc: 0.9918 - val_loss: 1.2094 - val_acc: 0.8286
Epoch 143/150
loss: 0.0292 - acc: 0.9919 - val_loss: 1.1449 - val_acc: 0.8358
Epoch 144/150
loss: 0.0284 - acc: 0.9925 - val_loss: 1.2666 - val_acc: 0.8267
Epoch 145/150
loss: 0.0328 - acc: 0.9913 - val_loss: 1.1720 - val_acc: 0.8331
Epoch 146/150
loss: 0.0270 - acc: 0.9928 - val_loss: 1.2077 - val_acc: 0.8355
Epoch 147/150
loss: 0.0338 - acc: 0.9907 - val_loss: 1.2715 - val_acc: 0.8313
Epoch 148/150
loss: 0.0276 - acc: 0.9923 - val_loss: 1.3014 - val_acc: 0.8223
Epoch 149/150
loss: 0.0290 - acc: 0.9923 - val_loss: 1.2123 - val_acc: 0.8291
Epoch 150/150
loss: 0.0317 - acc: 0.9920 - val_loss: 1.2682 - val_acc: 0.8277

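Clearly the model is overfitting and needs more data. However, the same code works for cats vs. dogs with 10K images, where the validation accuracy exceeds 90% within 4-5 epochs. I need help with this.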

1 answer:

Answer 0 (score: 1)

Try the following suggestions (tune until you get the results you want):

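  • Change the optimizer (try adam).
  • Change the learning rate (try a small learning rate).
  • Add L2 regularization.
  • Add batch normalization layers.
  • Make the layers after Flatten like those in VGG16.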
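
As a rough illustration of the first four suggestions, here is a minimal sketch of a classifier head for the bottleneck features. The function name build_top_model, the L2 factor of 1e-3 and the learning rate of 1e-4 are assumptions chosen for the example, not tuned or recommended values, and placing batch normalization before the ReLU is only one of several common arrangements. The input shape (4, 4, 512) is what VGG16 without its top produces for 128x128 images.

```python
import numpy as np
from keras.layers import Activation, BatchNormalization, Dense, Dropout, Flatten, Input
from keras.models import Sequential
from keras.optimizers import Adam
from keras.regularizers import l2


def build_top_model(input_shape, weight_decay=1e-3, learning_rate=1e-4):
    """Classifier head for VGG16 bottleneck features (illustrative values only)."""
    model = Sequential([
        Input(shape=input_shape),
        Flatten(),
        # L2 penalty on the dense weights discourages large weights.
        Dense(512, kernel_regularizer=l2(weight_decay)),
        # Batch normalization placed before the nonlinearity (an assumption).
        BatchNormalization(),
        Activation('relu'),
        Dropout(0.5),
        Dense(1, activation='sigmoid'),
    ])
    # Adam with a small learning rate instead of rmsprop with its defaults.
    model.compile(optimizer=Adam(learning_rate=learning_rate),
                  loss='binary_crossentropy', metrics=['accuracy'])
    return model


# Smoke check on random data shaped like the 128x128 bottleneck features.
model = build_top_model((4, 4, 512))
dummy_features = np.random.rand(2, 4, 4, 512).astype('float32')
probabilities = model.predict(dummy_features, verbose=0)
```

In the question's train_top_model, this would stand in for the Sequential model, called as build_top_model(train_data.shape[1:]). Whether it actually narrows the gap between training and validation accuracy has to be checked on the real data.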