How to feed different data formats when training a multi-input Keras model

Time: 2017-12-02 21:19:03

Tags: tensorflow keras keras-layer

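I am trying to build a multi-input Keras model for classifying medical images. The inputs are (i) the raw image, which goes through a CNN, and (ii) a vector of traditional computer vision features, supplied as an auxiliary input. The two branches are then concatenated and fed into a small neural network of dense layers. The figure below shows the model architecture: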

[Figure: multi-input model architecture]

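The problem is that I cannot train the model with model.fit, because my image dataset is large (more than a million images) and does not fit in memory. I therefore need model.fit_generator, which lets me use a flow_from_directory call that reads images from a directory. However, the auxiliary input (the handcrafted computer vision features, which are just a vector of 11 features per image) cannot be passed in through model.fit_generator this way.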

How can I train this custom network?

Please let me know if my question needs to be more specific or needs more information.

1 answer:

Answer 0: (score: 0)

The solution is to use a concatenation layer in the middle of the multi-input model. See Section 7.1.2 of Deep Learning with Python.

from keras.models import Model
from keras import layers
from keras import Input

text_vocabulary_size = 10000
question_vocabulary_size = 10000
answer_vocabulary_size = 500

# The text input is a variable-length sequence of integers. 
# Note that you can optionally name the inputs.
text_input = Input(shape=(None,), dtype='int32', name='text')
# Embeds the inputs into a sequence of vectors of size 64.
# Embedding takes (input_dim, output_dim), i.e. (vocabulary size, vector size).
embedded_text = layers.Embedding(text_vocabulary_size, 64)(text_input)
# Encodes the vectors in a single vector via an LSTM
encoded_text = layers.LSTM(32)(embedded_text)
# Same process (with different layer instances) for the question
question_input = Input(shape=(None,), dtype='int32', name='question')
# Again (input_dim, output_dim): vocabulary size first, then vector size.
embedded_question = layers.Embedding(question_vocabulary_size, 32)(question_input)
encoded_question = layers.LSTM(16)(embedded_question)
# Concatenates the encoded question and encoded text
concatenated = layers.concatenate([encoded_text, encoded_question], axis=-1)
# Adds a softmax classifier on top
answer = layers.Dense(answer_vocabulary_size, activation='softmax')(concatenated)
# At model instantiation, you specify the two inputs and the output.
model = Model([text_input, question_input], answer)
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['acc'])
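
The model above shows how two inputs are joined, but not how to feed them from a generator, which is what the question asks. One possible approach is a custom Python generator that yields both inputs together for each batch. The sketch below is not from the book. The function name `multi_input_batches`, the `load_image` callback, and the input names `'image'` and `'features'` are all assumptions made for illustration. It also assumes the 11-feature vectors fit in memory as one array (roughly 44 MB per million images as float32), so only the images are loaded lazily.

```python
import numpy as np


def multi_input_batches(image_paths, aux_features, labels, load_image,
                        batch_size=32, shuffle=True, seed=0):
    """Yield ({'image': ..., 'features': ...}, labels) batches forever.

    `load_image` is a hypothetical user-supplied function that maps a
    file path to a fixed-size NumPy array (decoding, resizing, scaling).
    """
    aux_features = np.asarray(aux_features, dtype="float32")
    labels = np.asarray(labels)
    n = len(image_paths)
    rng = np.random.RandomState(seed)
    while True:  # Keras expects training generators to loop indefinitely
        order = rng.permutation(n) if shuffle else np.arange(n)
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            # Only the images of the current batch are held in memory.
            images = np.stack([load_image(image_paths[i]) for i in idx])
            # The dict keys are assumed to match the names given to the
            # model's Input layers, e.g. Input(..., name='image').
            yield {"image": images, "features": aux_features[idx]}, labels[idx]
```

Assuming the two Input layers of the image model are created with name='image' and name='features' (the latter with shape=(11,)), the generator could be passed to model.fit_generator together with steps_per_epoch set to the number of batches per epoch. In recent TensorFlow releases fit_generator is deprecated and model.fit accepts generators directly. Because the same index array selects the images, the feature rows, and the labels, the three stay aligned even when shuffling.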