如何在Keras中将模型拆分和合并在一起?

时间:2018-08-15 13:20:02

标签: machine-learning keras autoencoder

我正在尝试构建由2个自动编码器组成的堆叠式自动编码器模型。我有2个AE,但无法加入。

这是我到目前为止所拥有的

### AUTOENCODER 1 ###
X_input = Input(input_shape)
x = Conv2D(64, (4,1), activation='relu', padding='same')(X_input)
x = Conv2D(32, (3,2), activation='relu', padding='same')(x)
x = MaxPooling2D(name='encoded')(x)
encoded_shape = x.shape.as_list()

x = Conv2D(32, (3,2), activation='relu', padding='same')(x)
x = UpSampling2D(name='up1')(x)
x = Conv2D(64, (4,1), activation='relu', padding='same')(x)
x = Conv2D(1, (3,3), name='decoded', padding='same')(x)
ae1 = Model(X_input, x)

enc_layer_ae1 = ae1.get_layer('encoded').output

-

### AUTOENCODER 2 ###
X_input1 = Input(encoded_shape[1:])
x1 = Conv2D(24, (3,3), activation='relu', padding='same')(X_input1)
x1 = Conv2D(16, (2,2), activation='relu', padding='same')(x1)
x1 = MaxPooling2D((2,3), name='encoded')(x1)

x1 = UpSampling2D((2,3), name='up')(x1)
x1 = Conv2D(16, (2,2), activation='relu', padding='same')(x1)
x1 = Conv2D(24, (3,3), activation='relu', padding='same')(x1)
x1 = Conv2D(32, (1,1), padding='same')(x1)

ae2 = Model(X_input1, x1)

enc_layer_ae2 = ae2.get_layer('encoded').output

这时我想做的是通过堆叠创建另一个模型

  • ae1层,从0到encoded
  • ae2的相同图层
  • 更多Dense

所以最终我的模型应该看起来像ae1_input > ae1_conv2d > ae1_conv2d > ae1_encoded > ae2_input > ae2_conv > ae2_conv > ae2_encoded > dense > softmax

我尝试做类似的事情

ae2_split = Model(X_input1, enc_layer_ae2)

full_output = ae2_split(enc_layer_ae1)
full_output = Dense(150, activation='relu')(full_output)
full_output = Dense(7, activation='softmax')(full_output)

full_model = Model(enc_layer_ae1.input, full_output)

但是我认为这是不正确的。你能建议我做一个合适的方法吗?

谢谢。

1 个答案:

答案 0 :(得分:1)

首先,您应该更改enc_layer_ae2层的输入。由于图层可以在keras中调用,因此您可以轻松地在另一层上调用一个图层。

enc_layer_ae1 = ae1.get_layer('encoded')
enc_layer_ae2 = ae2.get_layer('encoded')

enc_layer_ae2 = enc_layer_ae2(enc_layer_ae1.output)
full_output = Dense(150, activation='relu')(enc_layer_ae2)
full_output = Dense(7, activation='softmax')(full_output)
model = Model(enc_layer_ae1.input, full_output)