MobileNet ValueError:检查目标时出错:期望dense_1有4个维度,但得到的数组有形状(24,2)

时间:2017-10-30 09:57:07

标签: python tensorflow deep-learning keras

我正在尝试使用Keras应用程序实现多个网络。这里我附加了一段代码,这段代码适用于ResNet50和VGG16,但是当涉及到MobileNet时,它会产生错误:

  

ValueError:检查目标时出错:期望dense_1有4个维度,但得到的数组有形状(24,2)

我正在使用带有3个通道和批量大小为24的224x224图像,并尝试将它们分为2类,因此错误中提到的数字24是批量大小,但我不确定2号,可能是课程数量。

顺便问一下,是否有人知道我为什么会收到keras.applications.mobilenet的错误?

# basic_model = ResNet50()
# basic_model = VGG16()
basic_model = MobileNet()
classes = list(iter(train_generator.class_indices))
basic_model.layers.pop()
for layer in basic_model.layers[:25]:
    layer.trainable = False
last = basic_model.layers[-1].output
temp = Dense(len(classes), activation="softmax")(last)

fineTuned_model = Model(basic_model.input, temp)
fineTuned_model.classes = classes
fineTuned_model.compile(optimizer=Adam(lr=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])
fineTuned_model.fit_generator(
        train_generator,
        steps_per_epoch=3764 // batch_size,
        epochs=100,
        validation_data=validation_generator,
        validation_steps=900 // batch_size)
fineTuned_model.save('mobile_model.h5')

1 个答案:

答案 0 :(得分:1)

从源代码中,我们可以看到您正在弹出Reshape()图层。正是将卷积输出(4D)转换为类张量(2D)的那个。

Source code:

if include_top:
    if K.image_data_format() == 'channels_first':
        shape = (int(1024 * alpha), 1, 1)
    else:
        shape = (1, 1, int(1024 * alpha))

    x = GlobalAveragePooling2D()(x)
    x = Reshape(shape, name='reshape_1')(x)
    x = Dropout(dropout, name='dropout')(x)
    x = Conv2D(classes, (1, 1),
               padding='same', name='conv_preds')(x)
    x = Activation('softmax', name='act_softmax')(x)
    x = Reshape((classes,), name='reshape_2')(x)

但所有keras卷积模型都是以不同的方式使用。如果您需要自己的类数,则应使用include_top=False创建这些模型。这样,模型的最后部分(类部分)将根本不存在,您只需添加自己的图层:

basic_model = MobileNet(include_top=False)
for layer in basic_model.layers:
    layers.trainable=False

furtherOutputs = YourOwnLayers()(basic_model.outputs)

您应该尝试复制keras代码中显示的最后一部分,并使用您自己的类数更改classes。或者尝试从完整模型中弹出3个图层,ReshapeActivationConv2D,将它们替换为您自己的图层。