Error when running a convolutional autoencoder in Keras

Time: 2017-06-13 09:47:08

Tags: keras

I get the following error when running the code below in Keras:

Traceback (most recent call last):
  File "my_conv_ae.py", line 74, in <module>
    validation_steps = nb_validation_samples // batch_size)
  File "C:\Users\Administrator\AppData\Local\Programs\Python\Python35\lib\site-packages\keras\legacy\interfaces.py", line 88, in wrapper
    return func(*args, **kwargs)
  File "C:\Users\Administrator\AppData\Local\Programs\Python\Python35\lib\site-packages\keras\engine\training.py", line 1890, in fit_generator
    class_weight=class_weight)
  File "C:\Users\Administrator\AppData\Local\Programs\Python\Python35\lib\site-packages\keras\engine\training.py", line 1627, in train_on_batch
    check_batch_axis=True)
  File "C:\Users\Administrator\AppData\Local\Programs\Python\Python35\lib\site-packages\keras\engine\training.py", line 1309, in _standardize_user_data
    exception_prefix='target')
  File "C:\Users\Administrator\AppData\Local\Programs\Python\Python35\lib\site-packages\keras\engine\training.py", line 127, in _standardize_input_data
    str(array.shape))
ValueError: Error when checking target: expected conv2d_transpose_8 to have 4 dimensions, but got array with shape (20, 1)

The code is:

import keras
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D, Conv2DTranspose
from keras.models import Model
from keras import backend as K
from keras.preprocessing.image import ImageDataGenerator
import numpy as np

input_img = Input(shape=(512, 512, 1)) 
nb_train_samples = 1700
nb_validation_samples = 420
epochs = 10
batch_size = 20

x = Conv2D(64, (11, 11), activation='relu', strides= 1, padding='valid', kernel_initializer='glorot_uniform')(input_img)
x = Conv2D(64, (11, 11), activation='relu', strides= 1, padding='valid', kernel_initializer='glorot_uniform')(x)
x = MaxPooling2D((2, 2))(x)
x = Conv2D(128, (7, 7), activation='relu', strides= 1, padding='valid', kernel_initializer='glorot_uniform')(x)
x = Conv2D(128, (5, 5), activation='relu', strides= 1, padding='valid', kernel_initializer='glorot_uniform')(x)
x = MaxPooling2D((2, 2))(x)
x = Conv2D(256, (5, 5), activation='relu', strides= 1, padding='valid', kernel_initializer='glorot_uniform')(x)
x = Conv2D(256, (3, 3), activation='relu', strides= 1, padding='valid', kernel_initializer='glorot_uniform')(x)
x = MaxPooling2D((2, 2))(x)
x = Conv2D(512, (3, 3), activation='relu', strides= 1, padding='valid', kernel_initializer='glorot_uniform')(x)
x = Conv2D(512, (3, 3), activation='relu', strides= 1, padding='valid', kernel_initializer='glorot_uniform')(x)
encoded = MaxPooling2D((2, 2))(x)

print (K.int_shape(encoded))
At this point, the encoded representation has shape (26, 26, 512).

x = UpSampling2D((2, 2))(encoded)
x = Conv2DTranspose(512, (3, 3), activation='relu', strides= 1, padding='valid', kernel_initializer='glorot_uniform')(x)
x = Conv2DTranspose(512, (3, 3), activation='relu', strides= 1, padding='valid', kernel_initializer='glorot_uniform')(x)
x = UpSampling2D((2, 2))(x)
x = Conv2DTranspose(256, (3, 3), activation='relu', strides= 1, padding='valid', kernel_initializer='glorot_uniform')(x)
x = Conv2DTranspose(256, (5, 5), activation='relu', strides= 1, padding='valid', kernel_initializer='glorot_uniform')(x)
x = UpSampling2D((2, 2))(x)
x = Conv2DTranspose(128, (5, 5), activation='relu', strides= 1, padding='valid', kernel_initializer='glorot_uniform')(x)
x = Conv2DTranspose(128, (7, 7), activation='relu', strides= 1, padding='valid', kernel_initializer='glorot_uniform')(x)
x = UpSampling2D((2, 2))(x)
x = Conv2DTranspose(64, (11, 11), activation='relu', strides= 1, padding='valid', kernel_initializer='glorot_uniform')(x)
decoded = Conv2DTranspose(1, (11, 11), activation='relu', strides= 1, padding='valid', kernel_initializer='glorot_uniform')(x)

print (K.int_shape(decoded))

autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer = 'adadelta', loss = 'sparse_categorical_crossentropy', metrics = ['accuracy'])

train_datagen = ImageDataGenerator(
    rescale=1./255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)

test_datagen = ImageDataGenerator(rescale=1./255)

x_train = train_datagen.flow_from_directory(
    'data/train',
    target_size = (512, 512), color_mode = 'grayscale',
    batch_size = batch_size, 
    class_mode = 'binary')

x_test = test_datagen.flow_from_directory(
    'data/validation',
    target_size = (512, 512), color_mode = 'grayscale',
    batch_size = batch_size, 
    class_mode = 'binary')

autoencoder.fit_generator(
    x_train,
    steps_per_epoch = nb_train_samples // batch_size,
    epochs = epochs,
    validation_data = x_test,
    validation_steps = nb_validation_samples // batch_size)

decoded_imgs = autoencoder.predict(x_test)

The model summary is as follows:

Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, 512, 512, 1)       0
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 502, 502, 64)      7808
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 492, 492, 64)      495680
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 246, 246, 64)      0
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 240, 240, 128)     401536
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 236, 236, 128)     409728
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 118, 118, 128)     0
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 114, 114, 256)     819456
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 112, 112, 256)     590080
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 56, 56, 256)       0
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 54, 54, 512)       1180160
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 52, 52, 512)       2359808
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 26, 26, 512)       0
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 52, 52, 512)       0
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 54, 54, 512)       2359808
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 56, 56, 512)       2359808
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 112, 112, 512)     0
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 114, 114, 256)     1179904
_________________________________________________________________
conv2d_transpose_4 (Conv2DTr (None, 118, 118, 256)     1638656
_________________________________________________________________
up_sampling2d_3 (UpSampling2 (None, 236, 236, 256)     0
_________________________________________________________________
conv2d_transpose_5 (Conv2DTr (None, 240, 240, 128)     819328
_________________________________________________________________
conv2d_transpose_6 (Conv2DTr (None, 246, 246, 128)     802944
_________________________________________________________________
up_sampling2d_4 (UpSampling2 (None, 492, 492, 128)     0
_________________________________________________________________
conv2d_transpose_7 (Conv2DTr (None, 502, 502, 64)      991296
_________________________________________________________________
conv2d_transpose_8 (Conv2DTr (None, 512, 512, 1)       7745
=================================================================
Total params: 16,423,745
Trainable params: 16,423,745
Non-trainable params: 0
_________________________________________________________________
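
The spatial sizes in the summary can be checked by hand. The sketch below is plain Python, not Keras, and its helper names are hypothetical. It assumes stride 1 and 'valid' padding throughout, as in the code above, so a convolution shrinks each side by kernel - 1 and a transposed convolution grows it by kernel - 1.

```python
def conv_valid(size, kernel):
    # Conv2D, stride 1, padding='valid': output = input - kernel + 1
    return size - kernel + 1

def deconv_valid(size, kernel):
    # Conv2DTranspose, stride 1, padding='valid': output = input + kernel - 1
    return size + kernel - 1

def pool(size, factor=2):
    # MaxPooling2D((2, 2)) halves each spatial side
    return size // factor

def upsample(size, factor=2):
    # UpSampling2D((2, 2)) doubles each spatial side
    return size * factor

size = 512
# Encoder: two convolutions followed by one pooling step, four times.
for kernels in [(11, 11), (7, 5), (5, 3), (3, 3)]:
    for k in kernels:
        size = conv_valid(size, k)
    size = pool(size)
encoded_size = size

# Decoder: one upsampling step followed by two transposed convolutions, four times.
for kernels in [(3, 3), (3, 5), (5, 7), (11, 11)]:
    size = upsample(size)
    for k in kernels:
        size = deconv_valid(size, k)
decoded_size = size

print(encoded_size, decoded_size)  # 26 512
```

The decoder output comes back to 512 x 512, the same spatial size as the input, which agrees with the last row of the summary.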

Please help. Is this because I used Conv2DTranspose() for decoding?

1 Answer:

Answer 0: (score: 0)

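This is definitely not a problem with the model architecture itself (it works on my side). It looks like a problem with your ground-truth data. For an autoencoder the target must have the same dimensions as the input image, but flow_from_directory does not provide that kind of target. With class_mode = 'binary' it yields one class label per image, which is where the (20, 1) target shape in the error comes from. I think you need to use your own custom data generator.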
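
One way such a custom generator might look is sketched below, as an illustration rather than a tested fix for the code in the question. The wrapper `autoencoder_pairs` is a hypothetical name. It assumes the wrapped iterator yields image batches only, which is what flow_from_directory does when called with class_mode = None. To keep the example runnable without an image directory, `fake_image_batches` stands in for the Keras iterator and uses small 8 x 8 images.

```python
import numpy as np

def autoencoder_pairs(image_batches):
    """Yield (input, target) pairs in which the target is the input batch itself."""
    for batch in image_batches:
        yield batch, batch

def fake_image_batches(batch_size=20, size=8):
    # Stand-in for a Keras directory iterator created with class_mode=None,
    # which yields image batches without labels.
    while True:
        yield np.random.rand(batch_size, size, size, 1).astype("float32")

inputs, targets = next(autoencoder_pairs(fake_image_batches()))
print(inputs.shape, targets.shape)  # (20, 8, 8, 1) (20, 8, 8, 1)
```

With the code from the question, the idea would be to create x_train and x_test with class_mode = None and pass autoencoder_pairs(x_train) and autoencoder_pairs(x_test) to fit_generator. Because the target is now an image rather than a class index, a reconstruction loss such as 'mean_squared_error' is presumably a better fit than 'sparse_categorical_crossentropy'. Later Keras releases also appear to accept class_mode = 'input' for this purpose; check whether the installed version supports it.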