Keras: CNN - Conv2D incompatible input shape

Posted: 2021-02-03 14:58:23

Tags: python keras conv-neural-network

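I'm fairly new to CNNs and neural network architectures in general, and I'm trying to build a regression model from 2D images. I reshaped the data so that xpix is the image height and ypix is the image width.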

xpix = 322
ypix = 76
x_train = x_train.reshape(-1, 332, 76, 1)
x_val = x_val.reshape(-1, 332, 76, 1)
x_test = x_test.reshape(-1, 332, 76, 1)
print ('Training images shape: ', x_train.shape)
print ('Validation images shape: ', x_val.shape)
print ('Test Images shape:', x_test.shape)

Training images shape:  (1387, 332, 76, 1)
Validation images shape:  (463, 332, 76, 1)
Test Images shape: (617, 332, 76, 1)
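In the code above the image height is written twice: once as `xpix = 322` and once as the literal `332` in the `reshape` calls. Here is a minimal NumPy-only sketch of keeping a single source of truth. It uses a hypothetical zero-filled array in place of the real data and assumes the images really are 332 pixels tall:

```python
import numpy as np

# Hypothetical stand-in for the question's training data: a flat buffer
# holding n_samples single-channel images of 332 x 76 pixels.
n_samples = 8  # kept small for illustration; the question has 1387
xpix, ypix = 332, 76
raw_train = np.zeros(n_samples * xpix * ypix, dtype=np.float32)

# Reshape with the same variables that are later passed to input_shape,
# so the data and the model cannot disagree about the image size.
x_train = raw_train.reshape(-1, xpix, ypix, 1)

# Per-sample shape, suitable for Conv2D(..., input_shape=input_shape).
input_shape = x_train.shape[1:]

print("Training images shape:", x_train.shape)
print("input_shape for Conv2D:", input_shape)
```

In the Keras model, passing `input_shape=x_train.shape[1:]` to the first `Conv2D` layer would then pick up whatever size the data actually has.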

The model architecture I tried looks like this:

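from tensorflow.keras import optimizers
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization, Flatten, Dense, Dropout

Regressor = Sequential()
# Convolutional feature extractor; input_shape must match the per-sample shape of x_train
Regressor.add(Conv2D(64, kernel_size=(3, 3), activation='relu', padding='same', input_shape=(xpix, ypix, 1)))
Regressor.add(MaxPooling2D(pool_size=(2, 2), padding='same'))
Regressor.add(BatchNormalization())
# Regression head ending in a single linear output unit
Regressor.add(Flatten())
Regressor.add(Dense(64, activation='relu'))
Regressor.add(Dropout(0.1))
Regressor.add(Dense(1, activation='linear'))

adam = optimizers.Adam(learning_rate=0.0001)  # `lr` is the older, deprecated name for this argument
Regressor.compile(loss='mean_squared_error', optimizer=adam)
History = Regressor.fit(x_train, y_train, batch_size=16, epochs=35, validation_data=(x_val, y_val))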

But I get the following warning:

Epoch 1/35
WARNING:tensorflow:Model was constructed with shape (None, 322, 76, 1) for input Tensor("conv2d_3_input:0", shape=(None, 322, 76, 1), dtype=float32), but it was called on an input with incompatible shape (None, 332, 76, 1)
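The warning compares the shape the model was built for, `(None, 322, 76, 1)`, with the shape of the arrays passed to `fit`, `(None, 332, 76, 1)`. Only the height differs. The plain-Python sketch below uses a hypothetical helper, `shape_mismatches`, to find the differing axis. It also checks whether the data could hold a whole number of 322 × 76 single-channel images:

```python
# Shapes taken from the question: the model was built for 322 rows,
# but the arrays were reshaped to 332 rows.
model_input_shape = (322, 76, 1)
data_sample_shape = (332, 76, 1)

def shape_mismatches(expected, actual):
    """Return (axis, expected, actual) for every axis where the shapes differ."""
    return [(axis, e, a) for axis, (e, a) in enumerate(zip(expected, actual)) if e != a]

print(shape_mismatches(model_input_shape, data_sample_shape))
# → [(0, 322, 332)]

# Could the data be 322 x 76 images instead? Only if the total element
# count divides evenly by 322 * 76 (a single channel is assumed).
total = 1387 * 332 * 76
print(total % (322 * 76) == 0)
# → False
```

1387 × 332 is not a multiple of 322, so the training array cannot be read as a whole number of 322 × 76 images. Assuming one channel and no extra data, this suggests the images are 332 pixels tall and `xpix` should be 332 rather than 322.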

The model summary is:


Regressor.summary()
Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_4 (Conv2D)            (None, 322, 76, 64)       640       
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 161, 38, 64)       0         
_________________________________________________________________
batch_normalization_3 (Batch (None, 161, 38, 64)       256       
_________________________________________________________________
flatten_3 (Flatten)          (None, 391552)            0         
_________________________________________________________________
dense_5 (Dense)              (None, 64)                25059392  
_________________________________________________________________
dropout_2 (Dropout)          (None, 64)                0         
_________________________________________________________________
dense_6 (Dense)              (None, 1)                 65        
=================================================================
Total params: 25,060,353
Trainable params: 25,060,225
Non-trainable params: 128
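Every number in the summary can be reproduced by hand. This plain-Python sketch uses hypothetical helper names and assumes Keras's usual counting rules. Conv2D has one kernel weight per (height, width, input channel, filter) plus one bias per filter. `'same'` pooling gives `ceil(size / stride)`. BatchNormalization stores four values per channel, two of which are non-trainable moving statistics:

```python
import math

def conv2d_params(kh, kw, in_ch, filters):
    # One weight per kernel position, input channel and filter, plus one bias per filter.
    return kh * kw * in_ch * filters + filters

def dense_params(inputs, units):
    # Full weight matrix plus one bias per unit.
    return inputs * units + units

def pool_same_shape(h, w, stride=2):
    # With padding='same', each spatial dimension becomes ceil(size / stride).
    return math.ceil(h / stride), math.ceil(w / stride)

h, w, filters = 322, 76, 64            # height the model was built with
conv = conv2d_params(3, 3, 1, filters)  # 640
ph, pw = pool_same_shape(h, w)          # (161, 38)
bn_trainable = 2 * filters              # gamma and beta
bn_non_trainable = 2 * filters          # moving mean and moving variance
flat = ph * pw * filters                # 391552 features after Flatten
dense1 = dense_params(flat, 64)         # 25059392
dense2 = dense_params(64, 1)            # 65

total = conv + bn_trainable + bn_non_trainable + dense1 + dense2
print("Total params:", total)
print("Trainable params:", total - bn_non_trainable)
print("Non-trainable params:", bn_non_trainable)
```

The first Dense layer accounts for 25,059,392 of the 25,060,353 parameters, because Flatten feeds it 161 × 38 × 64 = 391,552 values.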

Any comments or feedback would be much appreciated!

0 answers