自动编码器:解码器的大小与编码器不同

时间:2019-04-30 15:04:10

标签: python tensorflow keras deep-learning autoencoder

如果将解码器构建为编码器的镜像,则最后一层的输出大小不匹配。

这是模型摘要:

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
conv_1_j (Conv2D)            (None, 28, 28, 64)        640       
_________________________________________________________________
batch_normalization_v2 (Batc (None, 28, 28, 64)        256       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 64)        0         
_________________________________________________________________
conv_2_j (Conv2D)            (None, 14, 14, 64)        36928     
_________________________________________________________________
batch_normalization_v2_1 (Ba (None, 14, 14, 64)        256       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 7, 7, 64)          0         
_________________________________________________________________
conv_3_j (Conv2D)            (None, 7, 7, 64)          36928     
_________________________________________________________________
batch_normalization_v2_2 (Ba (None, 7, 7, 64)          256       
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 3, 3, 64)          0         
_________________________________________________________________
conv_4_j (Conv2D)            (None, 3, 3, 64)          36928     
_________________________________________________________________
batch_normalization_v2_3 (Ba (None, 3, 3, 64)          256       
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 1, 1, 64)          0         
_________________________________________________________________
flatten (Flatten)            (None, 64)                0         
_________________________________________________________________
dense_1_j (Dense)            (None, 64)                4160      
_________________________________________________________________
reshape_out (Lambda)         (None, 1, 1, 64)          0         
_________________________________________________________________
conv2d (Conv2D)              (None, 1, 1, 64)          36928     
_________________________________________________________________
batch_normalization_v2_4 (Ba (None, 1, 1, 64)          256       
_________________________________________________________________
up_sampling2d (UpSampling2D) (None, 2, 2, 64)          0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 2, 2, 64)          36928     
_________________________________________________________________
batch_normalization_v2_5 (Ba (None, 2, 2, 64)          256       
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 4, 4, 64)          0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 4, 4, 64)          36928     
_________________________________________________________________
batch_normalization_v2_6 (Ba (None, 4, 4, 64)          256       
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 8, 8, 64)          0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 8, 8, 64)          36928     
_________________________________________________________________
batch_normalization_v2_7 (Ba (None, 8, 8, 64)          256       
_________________________________________________________________
up_sampling2d_3 (UpSampling2 (None, 16, 16, 64)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 16, 16, 1)         577       
=================================================================
Total params: 265,921
Trainable params: 264,897
Non-trainable params: 1,024
_________________________________________________________________

要复制的代码:

import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
from tensorflow.python.keras.layers import Lambda

from tensorflow.python.keras.callbacks import TensorBoard, ModelCheckpoint, EarlyStopping


def resize(example):
    image = example['image']
    image = tf.image.resize(image, [28, 28])
    image = tf.image.rgb_to_grayscale(image, )
    image = image / 255

    example['image'] = image
    return example


def get_tupel(example):
    return example['image'], example['image']


def gen_dataset(dataset, batch_size):
    dataset = dataset.map(resize, num_parallel_calls=4)
    dataset = dataset.map(get_tupel, num_parallel_calls=4)
    dataset = dataset.shuffle(batch_size*50).repeat()  # infinite stream
    dataset = dataset.prefetch(10000)
    dataset = dataset.batch(batch_size)
    return dataset


def main():
    builder = tfds.builder("cifar10")
    builder.download_and_prepare()
    datasets = builder.as_dataset()
    train_dataset, test_dataset = datasets['train'], datasets['test']

    batch_size = 48

    train_dataset = gen_dataset(train_dataset, batch_size)
    test_dataset = gen_dataset(test_dataset, batch_size)

    device = '/cpu:0' if not tf.test.is_gpu_available() else tf.test.gpu_device_name()
    print(tf.test.gpu_device_name())
    with tf.device(device):
        filters = 64
        kernel = 3
        pooling = 2
        image_size = 28

        inp_layer = tf.keras.layers.Input(shape=(image_size, image_size, 1))
        cnn_embedding_out = cnn_encoder(inp_layer, filters, kernel, pooling)
        cnn_decoder_out = cnn_decoder(cnn_embedding_out, filters, kernel, pooling)

        model = tf.keras.Model(inputs=inp_layer, outputs=cnn_decoder_out)

        model.compile(optimizer=tf.optimizers.Adam(0.0001), loss='binary_crossentropy',
                      metrics=['accuracy'])

        print(model.summary())

        model.fit(train_dataset, validation_data=test_dataset,
                  steps_per_epoch=100,  # 1000
                  validation_steps=100,
                  epochs=150,)


def cnn_encoder(inp_layer, filters, kernel, pooling):
    cnn1 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu', name='conv_1_j')(inp_layer)
    bn1 = tf.keras.layers.BatchNormalization()(cnn1)
    max1 = tf.keras.layers.MaxPooling2D(pooling, pooling, padding="valid")(bn1)
    cnn2 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu', name='conv_2_j')(max1)
    bn2 = tf.keras.layers.BatchNormalization()(cnn2)
    max2 = tf.keras.layers.MaxPooling2D(pooling, pooling, padding="valid")(bn2)
    cnn3 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu', name='conv_3_j')(max2)
    bn3 = tf.keras.layers.BatchNormalization()(cnn3)
    max3 = tf.keras.layers.MaxPooling2D(pooling, pooling, padding="valid")(bn3)
    cnn4 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu', name='conv_4_j')(max3)
    bn4 = tf.keras.layers.BatchNormalization()(cnn4)
    max4 = tf.keras.layers.MaxPooling2D(pooling, pooling, padding="valid")(bn4)
    flat = tf.keras.layers.Flatten()(max4)
    fc = tf.keras.layers.Dense(64, name='dense_1_j')(flat)  # this is the encoder layer!

    return fc


def cnn_decoder(inp_layer, filters, kernel, pooling):
    res1 = reshape([1, 1, filters], name="reshape_out")(inp_layer)
    cnn1 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(res1)
    bn1 = tf.keras.layers.BatchNormalization()(cnn1)
    up1 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn1)
    cnn2 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up1)
    bn2 = tf.keras.layers.BatchNormalization()(cnn2)
    up2 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn2)
    cnn3 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up2)
    bn3 = tf.keras.layers.BatchNormalization()(cnn3)
    up3 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn3)
    cnn4 = tf.keras.layers.Conv2D(filters, kernel,  padding="same", activation='relu',)(up3)
    bn4 = tf.keras.layers.BatchNormalization()(cnn4)
    up4 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn4)
    decoded = tf.keras.layers.Conv2D(1, kernel,  padding="same", activation='sigmoid')(up4)
    return decoded


def reshape(dim, name="", complete=False):
    def func(x):
        if complete:
            ret = tf.reshape(x, dim)
        else:
            ret = tf.reshape(x, [-1, ] + dim)
        return ret
    return Lambda(func, name=name)


if __name__ == "__main__":
    main()

我尝试使用Conv2dTranspose和不同大小的上采样,但这并不正确。

我希望将输出作为输入(48,28,28,1)

我在做什么错了?

2 个答案:

答案 0 :(得分:2)

正如您在reshape_out (Lambda)中看到的那样,您拥有形状1,因此通过执行操作 像UpSampling2一样,您只需调整大小2, 4, 8, 16, 32.就可以按照自己的方式做,直到完成大小UpSampling2为止,32,然后再做两次{{1} Conv2D中的}得到形状(x,28,28,64)的结果

答案 1 :(得分:0)

我更改了cnn_decoder函数:

def cnn_decoder(inp_layer, filters, kernel, pooling):
    res1 = reshape([1, 1, filters], name="reshape_out")(inp_layer)
    cnn1 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(res1)
    bn1 = tf.keras.layers.BatchNormalization()(cnn1)
    up1 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn1)
    cnn2 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up1)
    bn2 = tf.keras.layers.BatchNormalization()(cnn2)
    up2 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn2)
    cnn3 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up2)
    bn3 = tf.keras.layers.BatchNormalization()(cnn3)
    up3 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn3)
    cnn4 = tf.keras.layers.Conv2D(filters, kernel,  padding="same", activation='relu',)(up3)
    bn4 = tf.keras.layers.BatchNormalization()(cnn4)
    up4 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn4)
    cnn5 = tf.keras.layers.Conv2D(filters, kernel,  padding="same", activation='relu',)(up4)
    bn5 = tf.keras.layers.BatchNormalization()(cnn5)
    up5 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn5)
    cnn6 = tf.keras.layers.Conv2D(filters, kernel, padding="valid",  activation='relu')(up5)
    bn5 = tf.keras.layers.BatchNormalization()(cnn6)
    decoded = tf.keras.layers.Conv2D(1, kernel, padding="valid",  activation='sigmoid')(bn5)
    return decoded

您觉得这对吗?