I am trying to build a multi-input Keras model for classifying medical images. The inputs will be (i) the raw image, which goes through a CNN, and (ii) traditional computer-vision features included as an auxiliary input. The two branches are then concatenated and fed into a small network of dense layers. Please refer to the image below for a description of the model architecture:
[Image: multi-input model architecture]
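In code, the architecture I have in mind looks roughly like the following sketch (the convolutional layer sizes, image shape, and number of classes are placeholders chosen for illustration, not my actual values):

from keras.models import Model
from keras import layers
from keras import Input

num_classes = 2  # placeholder; the real number of classes is not stated here

# CNN branch for the raw image (the input shape is an assumption)
image_input = Input(shape=(224, 224, 3), name='image')
x = layers.Conv2D(32, (3, 3), activation='relu')(image_input)
x = layers.MaxPooling2D((2, 2))(x)
x = layers.Conv2D(64, (3, 3), activation='relu')(x)
x = layers.GlobalAveragePooling2D()(x)

# Auxiliary branch: the 11 hand-crafted computer-vision features per image
aux_input = Input(shape=(11,), name='aux_features')

# Concatenate both branches and finish with a small dense network
merged = layers.concatenate([x, aux_input])
z = layers.Dense(64, activation='relu')(merged)
output = layers.Dense(num_classes, activation='softmax')(z)

model = Model([image_input, aux_input], output)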
The problem is that I cannot train the model with model.fit, because my image dataset is large (over a million images) and does not fit in memory. I therefore need to use model.fit_generator, which lets me call flow_from_directory to read the images from a directory. However, the auxiliary input (the hand-crafted computer-vision features, which are just a vector of 11 features per image) cannot be fed in through model.fit_generator.
How can I train this custom network?
Please let me know if my question needs to be more specific or if more information is needed.
Answer 0 (score: 0)
The solution is to use a concatenation layer in the middle of the multi-input model. See section 7.1.2 of Deep Learning with Python:

from keras.models import Model
from keras import layers
from keras import Input
text_vocabulary_size = 10000
question_vocabulary_size = 10000
answer_vocabulary_size = 500
# The text input is a variable-length sequence of integers.
# Note that you can optionally name the inputs.
text_input = Input(shape=(None,), dtype='int32', name='text')
# Embeds the inputs into a sequence of vectors of size 64
embedded_text = layers.Embedding(text_vocabulary_size, 64)(text_input)
# Encodes the vectors in a single vector via an LSTM
encoded_text = layers.LSTM(32)(embedded_text)
# Same process (with different layer instances) for the question
question_input = Input(shape=(None,), dtype='int32', name='question')
embedded_question = layers.Embedding(question_vocabulary_size, 32)(question_input)
encoded_question = layers.LSTM(16)(embedded_question)
# Concatenates the encoded question and encoded text
concatenated = layers.concatenate([encoded_text, encoded_question], axis=-1)
# Adds a softmax classifier on top
answer = layers.Dense(answer_vocabulary_size, activation='softmax')(concatenated)
# At model instantiation, you specify the two inputs and the output.
model = Model([text_input, question_input], answer)
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['acc'])
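The listing above only demonstrates the concatenation pattern with the book's text/question example; it does not cover the generator issue from the question. One common way to feed both inputs through model.fit_generator is to write a plain Python generator that loads the images itself and yields the auxiliary features alongside them. The sketch below is an assumption-laden illustration (not from the book or this answer): it assumes the image paths, the 11-feature vectors, and one-hot labels are available as index-aligned arrays, and the helper name multi_input_generator is made up.

import numpy as np
from keras.preprocessing.image import load_img, img_to_array

def multi_input_generator(image_paths, aux_features, labels, batch_size, target_size=(224, 224)):
    """Yield ([image_batch, aux_batch], label_batch) indefinitely.
    image_paths: list of n file paths; aux_features: array of shape (n, 11);
    labels: one-hot array of shape (n, num_classes); all assumed index-aligned."""
    n = len(image_paths)
    while True:
        order = np.random.permutation(n)
        for start in range(0, n, batch_size):
            batch_idx = order[start:start + batch_size]
            # Load and normalize the images for this batch
            images = np.stack([
                img_to_array(load_img(image_paths[i], target_size=target_size)) / 255.0
                for i in batch_idx
            ])
            # The list of two arrays matches the model's two inputs
            yield [images, aux_features[batch_idx]], labels[batch_idx]

# Hypothetical usage (paths, feats, y are whatever holds your data):
# model.fit_generator(multi_input_generator(paths, feats, y, batch_size=32),
#                     steps_per_epoch=len(paths) // 32, epochs=10)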