Keras UpSampling2D behaves inconsistently

Time: 2018-09-22 02:19:27

Tags: python-3.x tensorflow machine-learning neural-network keras

Here is my model:

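# Standalone Keras imports (the traceback below goes through keras/engine/training.py)
from keras.layers import Input, Conv2D, ReLU, Add, UpSampling2D
from keras.models import Model

# Assumed values, inferred from the summary below (350 * 4 = 1400, RGB images)
img_height, img_width, img_depth = 1400, 1400, 3

filters = 256
kernel_size = 3
strides = 1
factor = 4  # the factor of upscaling

# Low-resolution input: (350, 350, 3)
inputLayer = Input(shape=(img_height // factor, img_width // factor, img_depth))
conv1 = Conv2D(filters, kernel_size, strides=strides, padding='same')(inputLayer)

# First residual block: conv -> ReLU -> conv, plus a skip connection from conv1
res = Conv2D(filters, kernel_size, strides=strides, padding='same')(conv1)
act = ReLU()(res)
res = Conv2D(filters, kernel_size, strides=strides, padding='same')(act)
res_rec = Add()([conv1, res])

# 15 more residual blocks, 16 in total
for i in range(15):  # 16-1
    res1 = Conv2D(filters, kernel_size, strides=strides, padding='same')(res_rec)
    act = ReLU()(res1)
    res2 = Conv2D(filters, kernel_size, strides=strides, padding='same')(act)
    res_rec = Add()([res_rec, res2])

# Global skip connection from conv1, then 4x upsampling of the 256 feature maps
conv2 = Conv2D(filters, kernel_size, strides=strides, padding='same')(res_rec)
a = Add()([conv1, conv2])
up = UpSampling2D(size=factor)(a)

# 1x1 convolution maps the 256 feature maps to 3 output channels
outputLayer = Conv2D(filters=3,
                     kernel_size=1,
                     strides=1,
                     padding='same')(up)

model = Model(inputs=inputLayer, outputs=outputLayer)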

model.summary() shows:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 350, 350, 3)  0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 350, 350, 256 7168        input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 350, 350, 256 590080      conv2d_1[0][0]                   
__________________________________________________________________________________________________
re_lu_1 (ReLU)                  (None, 350, 350, 256 0           conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 350, 350, 256 590080      re_lu_1[0][0]                    
__________________________________________________________________________________________________
add_1 (Add)                     (None, 350, 350, 256 0           conv2d_1[0][0]                   
                                                                 conv2d_3[0][0]                   
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 350, 350, 256 590080      add_1[0][0]                      
__________________________________________________________________________________________________
re_lu_2 (ReLU)                  (None, 350, 350, 256 0           conv2d_4[0][0]                   
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 350, 350, 256 590080      re_lu_2[0][0]                    
__________________________________________________________________________________________________
add_2 (Add)                     (None, 350, 350, 256 0           add_1[0][0]                      
                                                                 conv2d_5[0][0]                   
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 350, 350, 256 590080      add_2[0][0]                      
__________________________________________________________________________________________________
re_lu_3 (ReLU)                  (None, 350, 350, 256 0           conv2d_6[0][0]                   
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 350, 350, 256 590080      re_lu_3[0][0]                    
__________________________________________________________________________________________________
add_3 (Add)                     (None, 350, 350, 256 0           add_2[0][0]                      
                                                                 conv2d_7[0][0]                   

...... (this goes on for a long time) ......

__________________________________________________________________________________________________
add_15 (Add)                    (None, 350, 350, 256 0           add_14[0][0]                     
                                                                 conv2d_31[0][0]                  
__________________________________________________________________________________________________
conv2d_32 (Conv2D)              (None, 350, 350, 256 590080      add_15[0][0]                     
__________________________________________________________________________________________________
re_lu_16 (ReLU)                 (None, 350, 350, 256 0           conv2d_32[0][0]                  
__________________________________________________________________________________________________
conv2d_33 (Conv2D)              (None, 350, 350, 256 590080      re_lu_16[0][0]                   
__________________________________________________________________________________________________
add_16 (Add)                    (None, 350, 350, 256 0           add_15[0][0]                     
                                                                 conv2d_33[0][0]                  
__________________________________________________________________________________________________
conv2d_34 (Conv2D)              (None, 350, 350, 256 590080      add_16[0][0]                     
__________________________________________________________________________________________________
add_17 (Add)                    (None, 350, 350, 256 0           conv2d_1[0][0]                   
                                                                 conv2d_34[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 1400, 1400, 2 0           add_17[0][0]                     
__________________________________________________________________________________________________
conv2d_35 (Conv2D)              (None, 1400, 1400, 3 771         up_sampling2d_1[0][0]            
==================================================================================================
Total params: 19,480,579
Trainable params: 19,480,579
Non-trainable params: 0
__________________________________________________________________________________________________
None
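As a sanity check on the architecture, the Param # column can be reproduced by hand: a Conv2D layer with a bias has kernel_h * kernel_w * in_channels * out_channels weights plus out_channels biases. The sketch below uses a hypothetical helper, conv2d_params, and assumes conv2d_2 through conv2d_34 are the 33 identical 256-to-256 3x3 convolutions (2 in the first block, 2 * 15 in the loop, plus conv2).

```python
def conv2d_params(kernel_size, in_channels, out_channels):
    """Hypothetical helper: parameter count of a square Conv2D layer with bias."""
    # kernel weights + one bias per output channel
    return kernel_size * kernel_size * in_channels * out_channels + out_channels


first = conv2d_params(3, 3, 256)    # conv2d_1: RGB input -> 256 filters
body = conv2d_params(3, 256, 256)   # each of conv2d_2 ... conv2d_34: 256 -> 256
last = conv2d_params(1, 256, 3)     # conv2d_35: 1x1 conv, 256 -> 3 channels

# 33 body convolutions: 2 (first block) + 2 * 15 (loop) + 1 (conv2)
total = first + 33 * body + last

print(first, body, last, total)  # 7168 590080 771 19480579
```

The result matches "Total params: 19,480,579" in the summary, so the graph is wired as intended; the upsampling and Add layers contribute no parameters.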

The important part is at the end, near the output:

__________________________________________________________________________________________________
add_17 (Add)                    (None, 350, 350, 256 0           conv2d_1[0][0]                   
                                                                 conv2d_34[0][0]                  
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, 1400, 1400, 2 0           add_17[0][0]                     
__________________________________________________________________________________________________
conv2d_35 (Conv2D)              (None, 1400, 1400, 3 771         up_sampling2d_1[0][0]            
==================================================================================================

Now look at the error I get when I run the network:

Traceback (most recent call last):
  File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 280, in <module>
    setUpImages()
  File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 96, in setUpImages
    setUpData(trainData, testData)
  File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 135, in setUpData
    setUpModel(X_train, Y_train, validateTestData, trainingTestData)
  File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 176, in setUpModel
    train(model, X_train, Y_train, validateTestData, trainingTestData)
  File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 192, in train
    batch_size=32)
  File "C:\Users\payne\Anaconda3\envs\ml-gpu\lib\site-packages\keras\engine\training.py", line 950, in fit
    batch_size=batch_size)
  File "C:\Users\payne\Anaconda3\envs\ml-gpu\lib\site-packages\keras\engine\training.py", line 787, in _standardize_user_data
    exception_prefix='target')
  File "C:\Users\payne\Anaconda3\envs\ml-gpu\lib\site-packages\keras\engine\training_utils.py", line 137, in standardize_input_data
    str(data_shape))
ValueError: Error when checking target: expected conv2d_35 to have shape (1400, 1400, 1) but got array with shape (1400, 1400, 3)

Why does my last convolution expect a (1400, 1400, 1) tensor but receive a (1400, 1400, 3) tensor, when the summary says UpSampling2D should return a (1400, 1400, 2) tensor?

For context: this is a network that takes a 350x350x3 image and outputs a 1400x1400x3 image.
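For reference on the shapes involved: UpSampling2D with its default nearest-neighbour interpolation repeats rows and columns and leaves the channel axis alone, so up_sampling2d_1 should output (None, 1400, 1400, 256). The "2" in the summary is the first digit of 256, cut off by the fixed-width Output Shape column (the same truncation that drops the closing parenthesis on every row). The NumPy sketch below reproduces the operation on a small stand-in tensor; upsample_nearest is a hypothetical helper, and its equivalence to Keras is assumed for integer factors with channels_last data.

```python
import numpy as np


def upsample_nearest(x, size):
    """Hypothetical helper: nearest-neighbour upsampling of a channels_last batch (N, H, W, C)."""
    # Repeat every row `size` times, then every column `size` times;
    # the channel axis (last) is never touched.
    return np.repeat(np.repeat(x, size, axis=1), size, axis=2)


# Small stand-in for the (None, 350, 350, 256) feature map coming out of add_17
x = np.arange(2 * 2 * 256, dtype=np.float32).reshape(1, 2, 2, 256)
y = upsample_nearest(x, 4)
print(y.shape)  # (1, 8, 8, 256)

# Same arithmetic at full size: spatial dims grow 4x, channels stay 256
print((350 * 4, 350 * 4, 256))  # (1400, 1400, 256)
```

Only the final 1x1 convolution changes the channel count, from 256 to 3.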

1 Answer:

Answer 0 (score: 0)

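Apparently the error message is not really about conv2d_35 itself; it names the last layer of the network because that is the layer the loss function is attached to, and the check is performed on the target data.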

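Since I had chosen sparse_categorical_crossentropy as the loss function, Keras expects a single integer class index per output position instead of a vector, so it checks for targets of shape (1400, 1400, 1). My targets are RGB images of shape (1400, 1400, 3), hence the error.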
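The shape in the error message can be derived from the output shape and the loss. The sketch below is a hypothetical helper that mirrors this check under the assumption of Keras 2.x with channels_last data; it is not a Keras API.

```python
def expected_target_shape(output_shape, loss):
    """Hypothetical helper: target shape checked against the model output (batch axis omitted)."""
    if loss == 'sparse_categorical_crossentropy':
        # One integer class index per position, so the last axis collapses to 1
        return output_shape[:-1] + (1,)
    # Losses such as mean_squared_error compare target and output element-wise
    return output_shape


output_shape = (1400, 1400, 3)  # conv2d_35
print(expected_target_shape(output_shape, 'sparse_categorical_crossentropy'))  # (1400, 1400, 1)
print(expected_target_shape(output_shape, 'mean_squared_error'))               # (1400, 1400, 3)
```

The first result is exactly the (1400, 1400, 1) from the ValueError, independent of what UpSampling2D produces.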

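Setting the loss to mean_squared_error solves the problem, because it compares each output pixel with the corresponding target pixel and therefore expects targets with the same shape as the output, (1400, 1400, 3).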
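A NumPy sketch of what mean_squared_error computes for this image-to-image setup, assuming the Keras 2.x definition (mean of squared differences over the last axis). The small random batch is a hypothetical stand-in for real (N, 1400, 1400, 3) images.

```python
import numpy as np


def mean_squared_error(y_true, y_pred):
    # Average the squared error over the channel axis,
    # leaving one loss value per pixel
    return np.mean(np.square(y_pred - y_true), axis=-1)


rng = np.random.default_rng(0)
y_pred = rng.random((2, 8, 8, 3)).astype(np.float32)  # network output (stand-in)
y_true = rng.random((2, 8, 8, 3)).astype(np.float32)  # target images, same shape

per_pixel = mean_squared_error(y_true, y_pred)
print(per_pixel.shape)  # (2, 8, 8)
```

Because the loss is a plain element-wise difference, the target only has to match the output shape; no class-index layout is involved.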