将CNN和DNN输出合并为输入到keras中的新网络

时间:2019-06-02 11:04:44

标签: python keras

我正在尝试训练具有多输入类型的复杂模型。 问题-损失没有减少,我认为这是由于网络拓扑错误造成的。 我的代码正确吗?

模型:

输入:

a。 6张图片。 5个内容相似,1个是随机图像。

b。文字查询,描述了5张相似的图片。

输出:

模型需要预测随机图像是否类似于查询图像以及其他5张图像(是或否问题-二进制输出)。

问题-损失没有减少,我认为这是由于网络拓扑错误造成的。

def text_dnn(input_shape): #text query converted to vectors using word2vec 
    model = Sequential()
    model.add(Flatten(input_shape=input_shape)) 
    model.add(Dense(1024, activation='relu', name='text_combination'))
    model.add(Dropout(0.5))
    model.add(Dense(128, activation='relu', name='text_feature_vector'))
    return model

def image_cnn(input_shape,number_of_examples):
    images_inputs = [Input(image_input_shape) for i in range(5+1)]
    vgg_16= VGG16(weights='imagenet' ,input_shape=(224, 224, 3))
    vgg_16= Model(inputs=vgg_16.input, outputs=vgg_16.get_layer('fc2').output)
    model = Sequential()
    model.add(vgg_16)
    image_layer_combination = [model(image_input) for image_input in images_inputs]
    combination_layer  = concatenate(image_layer_combination)
    fc_1 = Dense(2048,activation='relu', name='image_combination')(combination_layer)
    image_model  = Model(inputs=images_inputs,outputs=fc_1)
    return image_model

def mergeCnnModel(cnnModel, cnnModel2):
    merged = concatenate([cnnModel.layers[-1].output, cnnModel2.layers[-1].output])
    dense1 = Dense(1024, activation='relu', name='image_text_combination')(merged)
    drop1 = Dropout(0.5)(dense1)
    dense2 = Dense(128, activation='relu', name='feature_vector')(drop1)
    outputs = Dense(1, activation='sigmoid')(dense2)
    model = Model(inputs=cnnModel.inputs+cnnModel2.inputs, outputs=outputs)
    return model


def get_complex_model(image_input_shape, number_of_examples,text_input_shape):

    #text branch
    text_model = text_dnn(text_input_shape)
    #images branch
    image_model=image_cnn(image_input_shape, number_of_examples)
    #merge branches
    new_model= mergeCnnModel(image_model,text_model )
    return new_model

combined_model  = get_complex_model(image_input_shape,number_of_examples,text_input_shape)

模型是训练,但损失保持不变。

0 个答案:

没有答案