keras模型具有良好的准确性和val acc,但无法预测甚至预测训练数据

时间:2020-06-10 13:17:33

标签: python tensorflow keras deep-learning

我正在尝试建立计算机视觉来对寺庙图像进行分类,该图像包含16个类别,每个类别大约60-70个图像(很少有只有40个数据的图像)。

这是我的数据传播图:

Plot for amount of data in each class

我正在使用VGG19架构,并稍微修改了最后一层。

def get_base_model():
    model = VGG19(input_shape=(224, 224, 3), weights='imagenet', include_top=False)
    model.layers.pop()
    model.layers.pop()
    model.layers.pop()
    model.outputs = [model.layers[-1].output]
    model.layers[-2].outbound_nodes= []
    x = Conv2D(256, kernel_size=(2,2),strides=2)(model.output)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)    
    x = Conv2D(128, kernel_size=(2,2),strides=1)(x)
    x = BatchNormalization()(x)
    x = Dropout(0.3)(x)
    x = Activation('relu')(x)
    x = Flatten()(x)
    x = Dense(len(class_names), activation='softmax')(x)
    model=Model(model.input,x)

    for layer in model.layers[:22]:
        layer.trainable = False

    return model
optimizer = SGD(lr=0.0001, momentum=0.9)
model.compile(optimizer=optimizer,
              loss='categorical_crossentropy',
              metrics=['accuracy'])

model.fit(
    train_data_gen,
    steps_per_epoch=int(np.ceil(train_data_gen.n/batch_size)),
    validation_data=val_data_gen,
    epochs=epochs,
    validation_steps=int(np.ceil(val_data_gen.n/batch_size)),
    shuffle=True,
    callbacks=[callback]
)

我使用从目录流出来的图像增强

image_gen_train = ImageDataGenerator(
                    rescale=1./255,
                    #preprocessing_function=preprocess_input,
                    width_shift_range=0.02,
                    height_shift_range=0.02,
                    horizontal_flip=False,
                    fill_mode='nearest'
                    )

train_data_gen = image_gen_train.flow_from_directory(
                                                batch_size=batch_size,
                                                directory=train_dir,
                                                shuffle=True,
                                                target_size=(IMG_SHAPE,IMG_SHAPE),
                                                class_mode='categorical',
                                                )

image_gen_val = ImageDataGenerator(
                                    rescale=1./255
                                    #preprocessing_function=preprocess_input
                                    )

val_data_gen = image_gen_val.flow_from_directory(batch_size=batch_size,
                                                 directory=test_dir,
                                                 target_size=(IMG_SHAPE, IMG_SHAPE),
                                                 class_mode='categorical',
                                                 shuffle=False)

(对不起,我很尴尬) 看到我评论了preprocessing_function参数,因为我尝试了使用和不使用它。

这是我尝试预测火车文件夹中所有数据的方法

folder = 'candi_borobudur'
imgs = [img_to_array(load_img(f'images/train/{folder}/{img}').resize((224, 224))) for img in os.listdir(f'images/train/{folder}')]
class_names = ['candi_borobudur', 'candi_brahu', 'candi_banyunibo', 'candi_cangkuang', 'candi_dieng', 'candi_sambisari', 'candi_kalasan', 'candi_pawon', 'candi_padas', 'candi_prambanan', 'candi_jago', 'candi_jabung', 'candi_muara_takus', 'candi_mendut', 'candi_sewu', 'candi_sari_']
print(len(imgs))
for img in imgs:
    #img = preprocess_input(img)
    img = img/255.
    img = np.expand_dims(img, axis=0)
    result = model.predict(img)
    print(result)
    print(np.argmax(result[0]))
    print(class_names[np.argmax(result[0], axis=-1)])

在大约40个纪元后,该模型获得了相当大的90%验证准确度,但是该模型甚至无法预测火车数据也无法给出准确的预测。当我运行上面的代码时,它可以预测其他类。

我所做的事情:

  • 将VGG19类的preprocessing_input用于培训和 预测。
  • 试图将最后一个激活函数更改为S型,relu和tanh。
  • 试图将优化器更改为adam,nadam并更改学习率。
  • 将最后一层更改为更简单的层(仅辍学和fc)
  • 清理数据,删除不良图像,使用更简单的图像。

我还没有做的事情:

  • 尝试收集更多数据

1 个答案:

答案 0 :(得分:0)

通过阅读以上所有内容,我认为您的网络很好。 您的 load_img 方法很有可能基于PIL。 与OpenCV不同的是,每个像素值都将在[0,1]范围内。您编码要求图像在[0,255]范围内并且为浮点型。

您能否在此行中提供load_img和img_to_array的详细信息。或打印由 img_to_array(load_img load_img

imgs = [img_to_array(load_img(f'images/train/{folder}/{img}').resize((224, 224))) for img in os.listdir(f'images/train/{folder}')]