我正在尝试创建一个与子模型具有相同输入的集成模型。
models = []
nb_models = 3
# Load each saved sub-model ("lstm_model1.h5", "lstm_model2.h5", ...) and
# rename it "model_1", "model_2", ... so every nested sub-model carries a
# distinct name inside the ensemble.
# NOTE(review): newer tf.keras makes Model.name read-only; there this
# assignment fails and model_tmp._name would be needed -- confirm against the
# installed Keras version.
for i in range(nb_models):
    model_tmp = load_model("lstm_model{}.h5".format(i + 1))
    model_tmp.name = "model_{}".format(i + 1)
    models.append(model_tmp)
def create_ensemble(models, model_input):
    """Build a model that averages the outputs of several sub-models.

    Every sub-model is called on the same ``model_input`` tensor, so all of
    them share one input layer; their outputs are combined with
    ``layers.average``.

    Args:
        models: sequence of Keras models to combine. Each must accept
            ``model_input``, and their outputs must have matching shapes so
            they can be averaged.
        model_input: Keras ``Input`` tensor fed to every sub-model.

    Returns:
        A ``Model`` named ``'ensemble'`` mapping ``model_input`` to the
        averaged prediction (or to the lone sub-model's output when only one
        model is given).

    Raises:
        ValueError: if ``models`` is empty.
    """
    if not models:
        raise ValueError("create_ensemble() needs at least one model")
    # Take in the outputs of all sub-models, evaluated on the shared input.
    out_models = [model(model_input) for model in models]
    # Keras merge layers such as Average require at least two inputs, so a
    # single sub-model's output is used directly rather than "averaged".
    if len(out_models) == 1:
        out_avg = out_models[0]
    else:
        out_avg = layers.average(out_models)
    # Wrap the shared input and the combined output into one model.
    return Model(inputs=model_input, outputs=out_avg, name='ensemble')
# The sub-models appear to be built with a fixed batch size: their summaries
# report (1, 1) layer outputs, and the feed error expects shape [1, 1, 1]
# (presumably stateful LSTMs -- confirm). Passing only
# ``shape=input_shape[1:]`` gives the ensemble a (None, 1, 1) input that no
# longer matches the sub-models' own input placeholders, so keep the full
# batch shape, batch dimension included.
model_input = Input(batch_shape=models[0].input_shape)
modelEns = create_ensemble(models, model_input)
# NOTE(review): with a fixed batch size of 1, predictions must also use
# batch_size=1, e.g. modelEns.predict(test_reshaped, batch_size=1).
当我加载集成模型,并向其输入与各个子模型相同的数据时,出现以下错误:
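You must feed a value for placeholder tensor 'lstm_2_input' with dtype float and shape [1,1,1] [[{{node lstm_2_input}}]](即:必须为 dtype 为 float、形状为 [1,1,1] 的占位符张量 'lstm_2_input' 提供输入值)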
三个子模型的结构均如下:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
lstm_1 (LSTM) (1, 1) 12
_________________________________________________________________
dense_1 (Dense) (1, 1) 2
=================================================================
对于集成模型:
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 1, 1) 0
__________________________________________________________________________________________________
model_1 (Sequential) multiple 14 input_1[0][0]
__________________________________________________________________________________________________
model_2 (Sequential) multiple 14 input_1[0][0]
__________________________________________________________________________________________________
model_3 (Sequential) multiple 14 input_1[0][0]
__________________________________________________________________________________________________
average_1 (Average) (None, 1) 0 model_1[1][0]
model_2[1][0]
model_3[1][0]
==================================================================================================
# ``shape`` is an attribute, not a method (assuming test_reshaped is a NumPy
# array); calling ``shape()`` raises TypeError: 'tuple' object is not callable.
print(test_reshaped.shape)
(28, 1, 1)
答案 0(得分:0)
请参见以下示例(摘自原回答中“here”链接所指向的文章):
# Multiple Inputs
from keras.utils import plot_model
from keras.models import Model
from keras.layers import Input
from keras.layers import Dense
from keras.layers import Flatten
from keras.layers.convolutional import Conv2D
from keras.layers.pooling import MaxPooling2D
from keras.layers.merge import concatenate
# First input branch: a (64, 64, 1) input passed through two
# Conv2D/MaxPooling2D stages and flattened into a vector.
visible1 = Input(shape=(64, 64, 1))
conv11 = Conv2D(filters=32, kernel_size=4, activation='relu')(visible1)
pool11 = MaxPooling2D(pool_size=(2, 2))(conv11)
conv12 = Conv2D(filters=16, kernel_size=4, activation='relu')(pool11)
pool12 = MaxPooling2D(pool_size=(2, 2))(conv12)
flat1 = Flatten()(pool12)
# Second input branch: a (32, 32, 3) input passed through the same
# Conv2D/MaxPooling2D stack and flattened into a vector.
visible2 = Input(shape=(32, 32, 3))
conv21 = Conv2D(filters=32, kernel_size=4, activation='relu')(visible2)
pool21 = MaxPooling2D(pool_size=(2, 2))(conv21)
conv22 = Conv2D(filters=16, kernel_size=4, activation='relu')(pool21)
pool22 = MaxPooling2D(pool_size=(2, 2))(conv22)
flat2 = Flatten()(pool22)
# Merge the two input branches into one feature vector.
merge = concatenate([flat1, flat2])
# Interpretation model: two Dense(10, relu) layers and a single sigmoid output.
hidden1 = Dense(10, activation='relu')(merge)
hidden2 = Dense(10, activation='relu')(hidden1)
output = Dense(1, activation='sigmoid')(hidden2)
# One model taking both inputs and producing the single output.
model = Model(inputs=[visible1, visible2], outputs=output)
# Summarize layers. summary() prints the table itself and returns None, so
# wrapping it in print() would emit a stray "None" line after the table.
model.summary()
模型摘要
____________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
====================================================================================================
input_1 (InputLayer) (None, 64, 64, 1) 0
____________________________________________________________________________________________________
input_2 (InputLayer) (None, 32, 32, 3) 0
____________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 61, 61, 32) 544 input_1[0][0]
____________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 29, 29, 32) 1568 input_2[0][0]
____________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 30, 30, 32) 0 conv2d_1[0][0]
____________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D) (None, 14, 14, 32) 0 conv2d_3[0][0]
____________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 27, 27, 16) 8208 max_pooling2d_1[0][0]
____________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, 11, 11, 16) 8208 max_pooling2d_3[0][0]
____________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D) (None, 13, 13, 16) 0 conv2d_2[0][0]
____________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D) (None, 5, 5, 16) 0 conv2d_4[0][0]
____________________________________________________________________________________________________
flatten_1 (Flatten) (None, 2704) 0 max_pooling2d_2[0][0]
____________________________________________________________________________________________________
flatten_2 (Flatten) (None, 400) 0 max_pooling2d_4[0][0]
____________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 3104) 0 flatten_1[0][0]
flatten_2[0][0]
____________________________________________________________________________________________________
dense_1 (Dense) (None, 10) 31050 concatenate_1[0][0]
____________________________________________________________________________________________________
dense_2 (Dense) (None, 10) 110 dense_1[0][0]
____________________________________________________________________________________________________
dense_3 (Dense) (None, 1) 11 dense_2[0][0]
====================================================================================================
Total params: 49,699
Trainable params: 49,699
Non-trainable params: 0
# 绘制模型结构图(plot graph)
# Render the model graph to 'multiple_inputs.png'.
# NOTE(review): plot_model presumably needs pydot and graphviz installed -- confirm.
plot_model(model, to_file='multiple_inputs.png')