global_average_pooling2d error when fine-tuning facenet in Keras

Time: 2018-06-15 10:58:08

Tags: python machine-learning keras conv-neural-network

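I am trying to fine-tune facenet_keras by following the guide "Fine-tune InceptionV3 on a new set of classes".

I get an error on the line that takes the output of the base model's last layer and adds a global spatial average pooling layer. I am not sure, but the error seems to be related to the input of the GlobalAveragePooling2D() layer.

How can I reshape the output layer to fit the GlobalAveragePooling2D() layer? The last layers look like this: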

Dropout (Dropout)               (None, 1792)         0           AvgPool[0][0]                    
__________________________________________________________________________________________________
Bottleneck (Dense)              (None, 128)          229376      Dropout[0][0]                    
__________________________________________________________________________________________________
Bottleneck_BatchNorm (BatchNorm (None, 128)          384         Bottleneck[0][0]                 

The error is:

ValueError                                Traceback (most recent call last)
<ipython-input-24-4a1c01c5761e> in <module>()
      4 
      5 x = base_model.output
----> 6 x = GlobalAveragePooling2D()(x)
      7 x = Dense(256, activation='relu')(x)
      8 predictions = Dense(12, activation='softmax')(x)

/usr/local/lib/python3.6/dist-packages/keras/engine/topology.py in __call__(self, inputs, **kwargs)
    573                 # Raise exceptions in case the input is not compatible
    574                 # with the input_spec specified in the layer constructor.
--> 575                 self.assert_input_compatibility(inputs)
    576 
    577                 # Collect input shapes to build layer.

/usr/local/lib/python3.6/dist-packages/keras/engine/topology.py in assert_input_compatibility(self, inputs)
    472                                      self.name + ': expected ndim=' +
    473                                      str(spec.ndim) + ', found ndim=' +
--> 474                                      str(K.ndim(x)))
    475             if spec.max_ndim is not None:
    476                 ndim = K.ndim(x)

ValueError: Input 0 is incompatible with layer global_average_pooling2d_2: expected ndim=4, found ndim=2
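
For context, the shape mismatch in this message can be reproduced outside Keras. The following is a minimal NumPy sketch, not Keras code. It assumes a channels-last layout, and the batch and spatial sizes (2, 5, 5) and the helper name global_average_pool_2d are made up for illustration. Global average pooling takes the mean over the two spatial axes of a 4D tensor, so a 2D tensor such as the (None, 128) output of Bottleneck_BatchNorm has no spatial axes left to average.

```python
import numpy as np


def global_average_pool_2d(x):
    """Average over the height and width axes of a (batch, h, w, channels) array."""
    return x.mean(axis=(1, 2))


# A 4D feature map, as a convolutional block would produce (sizes are illustrative).
feature_maps = np.ones((2, 5, 5, 1792))
pooled = global_average_pool_2d(feature_maps)
print(pooled.shape)

# A 2D embedding batch, like the (None, 128) output listed in the summary.
embeddings = np.ones((2, 128))
try:
    global_average_pool_2d(embeddings)
    pooling_failed = False
except ValueError as exc:
    # NumPy reports that axis 2 does not exist for a 2D array.
    pooling_failed = True
    print("cannot pool a 2D array:", exc)
```

The summary also lists an AvgPool layer feeding Dropout with shape (None, 1792), which suggests the loaded model has already pooled its feature maps before the Bottleneck layer.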

Here is the code:

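# Imports are assumed here; the original post does not show them.
from keras.layers import Dense, GlobalAveragePooling2D
from keras.models import Model, load_model

train_data_path = 'dataset_cfps/train'
validation_data_path = 'dataset_cfps/validation'
test_data_path = 'test'

# Parameters
img_width, img_height = 200, 200

# Paths to the weights file and the saved model file.
weights_path = 'keras-facenet/weights/facenet_keras_weights.h5'
top_model_weights_path = 'keras-facenet/model/facenet_keras.h5'

# Load the full pretrained facenet model.
base_model = load_model(top_model_weights_path)
base_model.summary()

x = base_model.output  # shape (None, 128) according to the summary
x = GlobalAveragePooling2D()(x)  # this is the line that raises the ValueError
x = Dense(256, activation='relu')(x)
predictions = Dense(12, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=predictions)

# Freeze the pretrained layers so that only the new layers are trained.
for layer in base_model.layers:
    layer.trainable = False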
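
One possible way around the error, offered as a sketch rather than a confirmed fix: since the summary shows that the base model already ends in a 2D output of shape (None, 128), the extra GlobalAveragePooling2D() layer can be dropped and the new Dense layers attached directly to base_model.output. The example assumes tensorflow.keras is available. Because facenet_keras.h5 is not loaded here, it uses a small hypothetical stand-in network (build_stand_in_base, with an arbitrary 64x64 input) that only mimics the tail of the real model.

```python
from tensorflow.keras.layers import (BatchNormalization, Conv2D, Dense,
                                     Dropout, GlobalAveragePooling2D, Input)
from tensorflow.keras.models import Model


def build_stand_in_base():
    """Hypothetical stand-in for the loaded facenet model (not the real network)."""
    inputs = Input(shape=(64, 64, 3))
    x = Conv2D(8, 3, activation='relu')(inputs)
    x = GlobalAveragePooling2D()(x)   # pooling already happens inside the base
    x = Dropout(0.2)(x)
    x = Dense(128, use_bias=False)(x)
    outputs = BatchNormalization()(x)
    return Model(inputs, outputs)


base_model = build_stand_in_base()

# Freeze the pretrained layers before compiling the new model.
for layer in base_model.layers:
    layer.trainable = False

# The base output is already 2D (batch, 128), so no pooling layer is added.
x = base_model.output
x = Dense(256, activation='relu')(x)
predictions = Dense(12, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=predictions)
print(model.output_shape)
```

If pooling over feature maps is really wanted, an alternative under the same assumptions is to branch from an earlier 4D layer of the base model instead of its final output.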

0 answers:

No answers