ValueError:检查目标时出错:预期预测的形状为(4,),但数组的形状为(1,)

时间:2018-12-27 08:06:09

标签: python keras

我搜索了几个类似的主题,涉及类似的问题。尽管如此,我仍然没有设法解决我的问题,为什么现在我要问社区。 我最终想要做的是使用转移学习来开发模型。我正在使用InceptionV3。我冻结了所有图层,并添加了最后一个密集层以预测4个类。 代码是:

base_model = InceptionV3(input_shape= (img_width, img_height, 3), weights='imagenet', include_top=False)
# # Top Model Block    
u = base_model.output
u = GlobalAveragePooling2D()(u)
u = Dense(256, activation='relu', name='fc1')(u)
u = Dropout(0.5)(u)

predictions = Dense(nb_classes, activation='softmax', name='predictions')(u)

model = Model(base_model.input, predictions)
for layer in base_model.layers:
    layer.trainable = False

这是我的培训代码,

model.fit_generator(train_generator,
                    steps_per_epoch=nb_train_samples // batch_size,
                    epochs=nb_epoch / 5,
                    validation_data=val_generator,
                    validation_steps=nb_validation_samples // batch_size,
                    callbacks=callbacks_list)

我正在使用以下代码进行编译,

model.compile(optimizer='nadam',
              loss='categorical_crossentropy',  
              metrics=['accuracy'])

我的数据扩充代码是这样的,

train_datagen = ImageDataGenerator(rescale=1. / 255,

                                  rotation_range=transformation_ratio,
                                   shear_range=transformation_ratio,
                                   zoom_range=transformation_ratio,
                                   cval=transformation_ratio,
                                   horizontal_flip=True,
                                   vertical_flip=True)

validation_datagen = ImageDataGenerator(rescale=1. / 255)


train_generator = train_datagen.flow(x_train,labels_train,batch_size=batch_size)

val_generator = validation_datagen.flow(x_val,labels_val,batch_size=batch_size)

请帮助我调试此错误。 注意:nb_classes = 4

1 个答案:

答案 0 :(得分:0)

标签的形状是什么?我想您还没有一键编码的labels_val和labels_train。

如果您的标签是:[0,0,1,3,2,2,0 ....],则使用“ sparse_categorical_crossentropy”

否则,在标签上应用“一次性编码”,这将是映射:

http
            .sessionManagement()
                .maximumSessions(2)
                .maxSessionsPreventsLogin(true);