keras多输入模型不起作用

时间:2018-06-27 03:26:28

标签: keras deep-learning

我有一个模型,可以使用kaggle的cats_vs_dogs数据集来区分猫和狗。我尝试了两种方法来做到这一点。对于第一个,我使用了三个现有模型(ResNet50,Xception InceptionV3)来提取特征,然后将训练数据通过这些模型的卷积基础进行预测和连接,然后将其用于独立的密集连接分类器。很好,经过五次训练后,val_acc变为99.58%。然后,我想使用数据扩充和微调,因此我在顶部添加了层,并在输入数据上端到端运行了整个模型,从而扩展了这三个模型。奇怪的是,第二种方法在训练中获得了不错的结果,但在验证中却很糟糕,并且val_acc始终为常数(0.5)。我感到非常困惑,这两种方式为什么会有如此不同的结果。 这是我的代码

from keras.models import *
from keras.layers import *
    from keras.applications import *
    from keras.preprocessing.image import *

    res_net_input = Input((224, 224, 3), name='res_net')
    res_net_base_model = ResNet50(input_tensor=res_net_input, weights='imagenet', include_top=False)
    for layers in res_net_base_model.layers:
        layers.trainable = False

    xception_input = Input((299, 299, 3), name='xception')
    xception_base_model = Xception(input_tensor=xception_input, weights='imagenet', include_top=False)
    for layers in xception_base_model.layers:
        layers.trainable = False

    inception_input = Input((299, 299, 3), name='inception')
    inception_base_model = InceptionV3(input_tensor=inception_input, weights='imagenet', include_top=False)
    for layers in inception_base_model.layers:
        layers.trainable = False

    res_result = GlobalAveragePooling2D()(res_net_base_model.output)
    xcp_result = GlobalAveragePooling2D()(xception_base_model.output)
    icp_result = GlobalAveragePooling2D()(inception_base_model.output)

    concatenated = concatenate([res_result, xcp_result, icp_result], axis=1)

    x = Dropout(0.5)(concatenated)
    x = Dense(1, activation='sigmoid')(x)
    model = Model([res_net_base_model.input, xception_base_model.input, inception_base_model.input], x)
    model.compile(optimizer='adadelta',
              loss='binary_crossentropy',
              metrics=['accuracy'])


    train_imgen = ImageDataGenerator(rescale = 1./255, 
                                    shear_range = 0.2, 
                                    zoom_range = 0.2,
                                    rotation_range=5.,
                                   horizontal_flip = True)

    validation_imgen = ImageDataGenerator(rescale = 1./255)


    def generate_generator_multiple(generator,dir1, batch_size, img_size1, img_size2, img_size3):
        genX1 = generator.flow_from_directory(dir1,
                                          target_size = (img_size1[0],img_size1[1]),
                                          class_mode = 'binary',
                                          batch_size = batch_size,
                                          shuffle=False, 
                                          )

        genX2 = generator.flow_from_directory(dir1,
                                          target_size = (img_size2[0],img_size2[1]),
                                          class_mode = 'binary',
                                          batch_size = batch_size,
                                          shuffle=False, 
                                          seed=7)
        genX3 = generator.flow_from_directory(dir1,
                                          target_size = (img_size3[0],img_size3[1]),
                                          class_mode = 'binary',
                                          batch_size = batch_size,
                                          shuffle=False, 
                                          seed=7)
        while True:
            X1i = genX1.next()
            X2i = genX2.next()
            X3i = genX3.next()
            yield [X1i[0], X2i[0],X3i[0]], X1i[1]  

    tain_generator = generate_generator_multiple(train_imgen , '/output/keras/dog_vs_cat_full/train', 100, (224,224), (299, 299),   (299, 299))
    validation_generator = generate_generator_multiple(validation_imgen,'/output/keras/dog_vs_cat_full/validation', 100,    (224,224), (299, 299), (299, 299))


    history=model.fit_generator(tain_generator,
                        steps_per_epoch=200,
                        epochs = 5,
                        validation_data = validation_generator,
                        validation_steps = 50,
                        shuffle=False) 

0 个答案:

没有答案