如果将解码器构建为编码器的镜像,则最后一层的输出大小不匹配。
这是模型摘要:
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 28, 28, 1)] 0
_________________________________________________________________
conv_1_j (Conv2D) (None, 28, 28, 64) 640
_________________________________________________________________
batch_normalization_v2 (Batc (None, 28, 28, 64) 256
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 64) 0
_________________________________________________________________
conv_2_j (Conv2D) (None, 14, 14, 64) 36928
_________________________________________________________________
batch_normalization_v2_1 (Ba (None, 14, 14, 64) 256
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 7, 7, 64) 0
_________________________________________________________________
conv_3_j (Conv2D) (None, 7, 7, 64) 36928
_________________________________________________________________
batch_normalization_v2_2 (Ba (None, 7, 7, 64) 256
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 3, 3, 64) 0
_________________________________________________________________
conv_4_j (Conv2D) (None, 3, 3, 64) 36928
_________________________________________________________________
batch_normalization_v2_3 (Ba (None, 3, 3, 64) 256
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 1, 1, 64) 0
_________________________________________________________________
flatten (Flatten) (None, 64) 0
_________________________________________________________________
dense_1_j (Dense) (None, 64) 4160
_________________________________________________________________
reshape_out (Lambda) (None, 1, 1, 64) 0
_________________________________________________________________
conv2d (Conv2D) (None, 1, 1, 64) 36928
_________________________________________________________________
batch_normalization_v2_4 (Ba (None, 1, 1, 64) 256
_________________________________________________________________
up_sampling2d (UpSampling2D) (None, 2, 2, 64) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 2, 2, 64) 36928
_________________________________________________________________
batch_normalization_v2_5 (Ba (None, 2, 2, 64) 256
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 4, 4, 64) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 4, 4, 64) 36928
_________________________________________________________________
batch_normalization_v2_6 (Ba (None, 4, 4, 64) 256
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 8, 8, 64) 0
_________________________________________________________________
conv2d_3 (Conv2D) (None, 8, 8, 64) 36928
_________________________________________________________________
batch_normalization_v2_7 (Ba (None, 8, 8, 64) 256
_________________________________________________________________
up_sampling2d_3 (UpSampling2 (None, 16, 16, 64) 0
_________________________________________________________________
conv2d_4 (Conv2D) (None, 16, 16, 1) 577
=================================================================
Total params: 265,921
Trainable params: 264,897
Non-trainable params: 1,024
_________________________________________________________________
要复制的代码:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
from tensorflow.python.keras.layers import Lambda
from tensorflow.python.keras.callbacks import TensorBoard, ModelCheckpoint, EarlyStopping
def resize(example):
image = example['image']
image = tf.image.resize(image, [28, 28])
image = tf.image.rgb_to_grayscale(image, )
image = image / 255
example['image'] = image
return example
def get_tupel(example):
return example['image'], example['image']
def gen_dataset(dataset, batch_size):
dataset = dataset.map(resize, num_parallel_calls=4)
dataset = dataset.map(get_tupel, num_parallel_calls=4)
dataset = dataset.shuffle(batch_size*50).repeat() # infinite stream
dataset = dataset.prefetch(10000)
dataset = dataset.batch(batch_size)
return dataset
def main():
builder = tfds.builder("cifar10")
builder.download_and_prepare()
datasets = builder.as_dataset()
train_dataset, test_dataset = datasets['train'], datasets['test']
batch_size = 48
train_dataset = gen_dataset(train_dataset, batch_size)
test_dataset = gen_dataset(test_dataset, batch_size)
device = '/cpu:0' if not tf.test.is_gpu_available() else tf.test.gpu_device_name()
print(tf.test.gpu_device_name())
with tf.device(device):
filters = 64
kernel = 3
pooling = 2
image_size = 28
inp_layer = tf.keras.layers.Input(shape=(image_size, image_size, 1))
cnn_embedding_out = cnn_encoder(inp_layer, filters, kernel, pooling)
cnn_decoder_out = cnn_decoder(cnn_embedding_out, filters, kernel, pooling)
model = tf.keras.Model(inputs=inp_layer, outputs=cnn_decoder_out)
model.compile(optimizer=tf.optimizers.Adam(0.0001), loss='binary_crossentropy',
metrics=['accuracy'])
print(model.summary())
model.fit(train_dataset, validation_data=test_dataset,
steps_per_epoch=100, # 1000
validation_steps=100,
epochs=150,)
def cnn_encoder(inp_layer, filters, kernel, pooling):
cnn1 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu', name='conv_1_j')(inp_layer)
bn1 = tf.keras.layers.BatchNormalization()(cnn1)
max1 = tf.keras.layers.MaxPooling2D(pooling, pooling, padding="valid")(bn1)
cnn2 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu', name='conv_2_j')(max1)
bn2 = tf.keras.layers.BatchNormalization()(cnn2)
max2 = tf.keras.layers.MaxPooling2D(pooling, pooling, padding="valid")(bn2)
cnn3 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu', name='conv_3_j')(max2)
bn3 = tf.keras.layers.BatchNormalization()(cnn3)
max3 = tf.keras.layers.MaxPooling2D(pooling, pooling, padding="valid")(bn3)
cnn4 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu', name='conv_4_j')(max3)
bn4 = tf.keras.layers.BatchNormalization()(cnn4)
max4 = tf.keras.layers.MaxPooling2D(pooling, pooling, padding="valid")(bn4)
flat = tf.keras.layers.Flatten()(max4)
fc = tf.keras.layers.Dense(64, name='dense_1_j')(flat) # this is the encoder layer!
return fc
def cnn_decoder(inp_layer, filters, kernel, pooling):
res1 = reshape([1, 1, filters], name="reshape_out")(inp_layer)
cnn1 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(res1)
bn1 = tf.keras.layers.BatchNormalization()(cnn1)
up1 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn1)
cnn2 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up1)
bn2 = tf.keras.layers.BatchNormalization()(cnn2)
up2 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn2)
cnn3 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up2)
bn3 = tf.keras.layers.BatchNormalization()(cnn3)
up3 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn3)
cnn4 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up3)
bn4 = tf.keras.layers.BatchNormalization()(cnn4)
up4 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn4)
decoded = tf.keras.layers.Conv2D(1, kernel, padding="same", activation='sigmoid')(up4)
return decoded
def reshape(dim, name="", complete=False):
def func(x):
if complete:
ret = tf.reshape(x, dim)
else:
ret = tf.reshape(x, [-1, ] + dim)
return ret
return Lambda(func, name=name)
if __name__ == "__main__":
main()
我尝试使用Conv2dTranspose和不同大小的上采样,但这并不正确。
我希望将输出作为输入(48,28,28,1)
我在做什么错了?
答案 0 :(得分:2)
正如您在reshape_out (Lambda)
中看到的那样,您拥有形状1,因此通过执行操作
像UpSampling2
一样,您只需调整大小2, 4, 8, 16, 32.
就可以按照自己的方式做,直到完成大小UpSampling2
为止,32
,然后再做两次{{1} Conv2D
中的}得到形状(x,28,28,64)的结果
答案 1 :(得分:0)
我更改了cnn_decoder
函数:
def cnn_decoder(inp_layer, filters, kernel, pooling):
res1 = reshape([1, 1, filters], name="reshape_out")(inp_layer)
cnn1 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(res1)
bn1 = tf.keras.layers.BatchNormalization()(cnn1)
up1 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn1)
cnn2 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up1)
bn2 = tf.keras.layers.BatchNormalization()(cnn2)
up2 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn2)
cnn3 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up2)
bn3 = tf.keras.layers.BatchNormalization()(cnn3)
up3 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn3)
cnn4 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up3)
bn4 = tf.keras.layers.BatchNormalization()(cnn4)
up4 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn4)
cnn5 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up4)
bn5 = tf.keras.layers.BatchNormalization()(cnn5)
up5 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn5)
cnn6 = tf.keras.layers.Conv2D(filters, kernel, padding="valid", activation='relu')(up5)
bn5 = tf.keras.layers.BatchNormalization()(cnn6)
decoded = tf.keras.layers.Conv2D(1, kernel, padding="valid", activation='sigmoid')(bn5)
return decoded
您觉得这对吗?