How do I connect a separately created encoder and decoder?

Posted: 2019-10-08 14:17:48

Tags: machine-learning keras autoencoder

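I want to build an AutoEncoder and use it. I have already created the encoder and the decoder, and with the summary() method I checked that each of them looks correct and that they are symmetric. Next, I want to connect the encoder to the decoder to get a final model that I can train. I tried to connect the two, but the summary of the final model looks wrong: it does not show all of the model's layers (some parts of the decoder are missing). The buld_model() function below is where I try to connect them. Thanks!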


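import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import (Input, Conv2D, Conv2DTranspose, LeakyReLU,
                                     Flatten, Dense, Reshape, Activation)

class AutoEncoder(object):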
    def __init__(self,
                 shape,
                 encoder_depth,
                 decoder_depth,
                 encoder_filters,
                 decoder_filters,
                 encoder_strides,
                 decoder_strides,
                 latent_space,
                ):

        self.shape = shape

        # Check provided filters (assert is a statement: wrapping the condition and
        # message in parentheses builds a tuple, which is always truthy and never fails)
        assert encoder_depth == len(encoder_filters), "Number of filters should match the depth of the encoder"
        assert decoder_depth == len(decoder_filters), "Number of filters should match the depth of the decoder"

        # Check provided strides
        assert encoder_depth == len(encoder_strides), "Number of strides should match the depth of the encoder"
        assert decoder_depth == len(decoder_strides), "Number of strides should match the depth of the decoder"

        # Depth and latent space
        self.encoder_depth = encoder_depth
        self.decoder_depth = decoder_depth
        self.latent_space = latent_space

        #Filters
        self.encoder_filters = encoder_filters
        self.decoder_filters = decoder_filters

        #Strides
        self.encoder_strides = encoder_strides
        self.decoder_strides = decoder_strides

        self.buld_model()

    def build_encoder(self):
        input_x = Input(shape=self.shape, name="encoder_input")
        x = input_x

        for i in range(self.encoder_depth):
            x = Conv2D(self.encoder_filters[i],
                       kernel_size = 3,
                       strides = self.encoder_strides[i],
                       padding="same",
                       name = "encoder_conv_" + str(i))(x)
            x = LeakyReLU()(x)

        self.shape_before_flat = K.int_shape(x)[1:]
        x = Flatten()(x)
        encoder_output = Dense(self.latent_space, name="Encoder_output")(x)

        self.encoder_output = encoder_output
        self.encoder_input = input_x

        self.encoder = tf.keras.Model(self.encoder_input , self.encoder_output)

    def build_decoder(self):
        decoder_input = Input(shape = (self.latent_space,), name="decoder_input")

        x = Dense(np.prod(self.shape_before_flat))(decoder_input)
        x = Reshape(self.shape_before_flat)(x)

        for i in range(self.decoder_depth):
            x = Conv2DTranspose(self.decoder_filters[i],
                                kernel_size = 3,
                                strides = self.decoder_strides[i],
                                padding="same",
                                name = "decoder_conv_t_" + str(i))(x)
            if i < self.decoder_depth - 1:
                x = LeakyReLU()(x)
            else:
                x = Activation("sigmoid")(x)

        decoder_output = x        
        self.decoder = tf.keras.Model(decoder_input , decoder_output)

    def buld_model(self):
        self.build_encoder()
        self.build_decoder()

        model_input = self.encoder_input
        model_output = self.decoder(self.encoder_output)

        self.model = tf.keras.Model(model_input,model_output)
        self.model.compile(optimizer="adam",loss = loss )

def loss(y_true, y_pred):
    return K.mean(K.square(y_true - y_pred),axis=[1,2,3])

autoencoder = AutoEncoder((28,28,1),
                          4,
                          4,
                          [32,64,64,64],
                          [64,64,32,1],
                          [1,2,2,1],
                          [1,2,2,1],
                          2)

autoencoder.model.summary()
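Since the question relies on the decoder mirroring the encoder, here is a small shape walk-through for the example call above. It assumes the standard size rules for padding="same": a Conv2D layer maps a spatial size n to ceil(n / stride), and a Conv2DTranspose layer (with no output_padding) maps it to n * stride. The helper names are hypothetical; the strides and the last encoder filter count (64) are the ones passed in the question.

```python
import math

def conv_same_out(size: int, stride: int) -> int:
    # Conv2D with padding="same": output size = ceil(input size / stride)
    return math.ceil(size / stride)

def conv_transpose_same_out(size: int, stride: int) -> int:
    # Conv2DTranspose with padding="same" and no output_padding: input size * stride
    return size * stride

def trace_shapes(size, encoder_strides, decoder_strides):
    # Spatial size after each encoder convolution
    enc = [size]
    for s in encoder_strides:
        enc.append(conv_same_out(enc[-1], s))
    # Reshape restores the pre-Flatten spatial size, then each transpose conv upsamples
    dec = [enc[-1]]
    for s in decoder_strides:
        dec.append(conv_transpose_same_out(dec[-1], s))
    return enc, dec

enc, dec = trace_shapes(28, [1, 2, 2, 1], [1, 2, 2, 1])
flat_size = enc[-1] * enc[-1] * 64  # units the decoder's first Dense must produce
print(enc)        # [28, 28, 14, 7, 7]
print(dec)        # [7, 7, 14, 28, 28]
print(flat_size)  # 3136
```

With these strides the decoder ends at 28 x 28, and the final Conv2DTranspose has 1 filter, so the reconstruction matches the (28, 28, 1) input.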

1 Answer:

Answer 0 (score: 1)

You've almost got it! From here it's a piece of cake. It should look something like this:

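    def buld_model(self):
        self.build_encoder()
        self.build_decoder()

        # Start from a fresh Input and pass it through both sub-models in turn
        img = Input(shape=self.shape)
        encoded_repr = self.encoder(img)
        reconstructed_img = self.decoder(encoded_repr)

        self.model = tf.keras.Model(img, reconstructed_img)
        self.model.compile(loss="mse", optimizer="adam")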
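Below is a self-contained sketch of the same pattern, assuming TensorFlow 2.x with tf.keras. The tiny encoder and decoder (their filter counts and the names "encoder", "decoder", "autoencoder") are hypothetical stand-ins for the question's class. It also hints at the most likely reason the summary looked incomplete: Keras lists a model that is used as a layer inside another model as a single row, so the decoder's own layers are not printed one by one even though they are part of the graph. In the question's version the encoder layers do show up individually, because the combined model starts from encoder_input itself, while the decoder is called as a nested model.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

# Hypothetical tiny encoder: 28x28x1 image -> 2-d latent vector
enc_in = layers.Input(shape=(28, 28, 1), name="encoder_input")
x = layers.Conv2D(8, 3, strides=2, padding="same")(enc_in)
x = layers.LeakyReLU()(x)
x = layers.Flatten()(x)
enc_out = layers.Dense(2, name="encoder_output")(x)
encoder = tf.keras.Model(enc_in, enc_out, name="encoder")

# Hypothetical tiny decoder: 2-d latent vector -> 28x28x1 image
dec_in = layers.Input(shape=(2,), name="decoder_input")
y = layers.Dense(14 * 14 * 8)(dec_in)
y = layers.Reshape((14, 14, 8))(y)
y = layers.Conv2DTranspose(1, 3, strides=2, padding="same")(y)
dec_out = layers.Activation("sigmoid")(y)
decoder = tf.keras.Model(dec_in, dec_out, name="decoder")

# Chain the two sub-models on a fresh Input, as in the answer
img = layers.Input(shape=(28, 28, 1), name="autoencoder_input")
reconstructed = decoder(encoder(img))
autoencoder = tf.keras.Model(img, reconstructed, name="autoencoder")
autoencoder.compile(optimizer="adam", loss="mse")

# "encoder" and "decoder" each appear as one row in the summary...
autoencoder.summary()
top_level_names = [layer.name for layer in autoencoder.layers]
print(top_level_names)

# ...but their layers are still there and can be reached through the nested model
decoder_layers = autoencoder.get_layer("decoder").layers
print([layer.name for layer in decoder_layers])

# Smoke test on random data: the reconstruction has the input's shape
batch = np.random.rand(4, 28, 28, 1).astype("float32")
out = np.asarray(autoencoder(batch, training=False))
print(out.shape)
```

Training then works as usual, e.g. autoencoder.fit(batch, batch, epochs=1), because the input and the reconstruction target are the same images.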

Hope this helps!
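The answer compiles with loss="mse" instead of the question's custom loss. The NumPy sketch below (random data, hypothetical batch of 4) suggests the two produce the same scalar training loss, assuming Keras's default reduction, which averages whatever per-element values the loss function returns. The custom loss averages over height, width and channels, giving one value per sample, whereas Keras's mean squared error averages over the last (channel) axis only, giving one value per pixel.

```python
import numpy as np

rng = np.random.default_rng(0)
y_true = rng.random((4, 28, 28, 1))
y_pred = rng.random((4, 28, 28, 1))

# Question's custom loss: mean over axes 1, 2, 3 -> one value per sample
per_sample = np.mean(np.square(y_true - y_pred), axis=(1, 2, 3))

# Keras "mse": mean over the last axis only -> one value per pixel
per_pixel = np.mean(np.square(y_true - y_pred), axis=-1)

# Averaging the remaining values (the assumed default reduction) gives the same scalar,
# because every sample contributes the same number of elements
loss_custom = per_sample.mean()
loss_mse = per_pixel.mean()
print(per_sample.shape, per_pixel.shape)  # (4,) (4, 28, 28)
print(np.isclose(loss_custom, loss_mse))  # True
```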