Data augmentation / how to improve accuracy

Time: 2020-06-15 11:54:45

Tags: python keras deep-learning medical fast-ai

I am working on a project (classification of diabetic retinopathy) with a small dataset (3,500 images). When I train the model, the accuracy is very low (around 50%), and I would like to raise it to at least 90% using the same dataset. I tried data augmentation with fastai v1, but it did not work for me; a rough sketch of that attempt is shown below.
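The fastai v1 attempt followed the usual ImageDataBunch pattern, roughly like the sketch below (illustrative only: the path, transform parameters and backbone are placeholders, not my exact notebook):

from fastai.vision import *   # fastai v1 API

# Fundus photographs have no preferred orientation, so flips and large rotations are safe.
tfms = get_transforms(do_flip=True, flip_vert=True, max_rotate=360,
                      max_zoom=1.1, max_lighting=0.2, max_warp=0.0)

# '/content/data' is a placeholder path; 'training'/'validation' match my folder layout.
data = (ImageDataBunch.from_folder('/content/data',
                                   train='training', valid='validation',
                                   ds_tfms=tfms, size=224, bs=16)
        .normalize(imagenet_stats))

learn = cnn_learner(data, models.resnet34, metrics=accuracy)
learn.fit_one_cycle(10)

Here is my current Keras code: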

from keras.preprocessing.image import ImageDataGenerator
from keras.preprocessing import image
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.models import Sequential, Model
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense
from keras import optimizers
from keras.applications import VGG16

import cv2


import os
import numpy as np
import itertools
import random
from collections import Counter
from glob import iglob

import warnings
warnings.filterwarnings('ignore')


from sklearn.metrics import confusion_matrix, classification_report

import matplotlib.pyplot as plt
%matplotlib inline 

BASE_DATASET_FOLDER ='/content/drive/My Drive/diabetic-retinopathy-master/diabetic-retinopathy-master.zip (Unzipped Files)/diabetic-retinopathy-master/data'
TRAIN_FOLDER ='/content/drive/My Drive/diabetic-retinopathy-master/diabetic-retinopathy-master.zip (Unzipped Files)/diabetic-retinopathy-master/data/training' 
VALIDATION_FOLDER ='/content/drive/My Drive/diabetic-retinopathy-master/diabetic-retinopathy-master.zip (Unzipped Files)/diabetic-retinopathy-master/data/validation'
TEST_FOLDER ='/content/drive/My Drive/diabetic-retinopathy-master/diabetic-retinopathy-master.zip (Unzipped Files)/diabetic-retinopathy-master/data/test'

IMAGE_SIZE = (224, 224)
INPUT_SHAPE = (224, 224, 3) 

TRAIN_BATCH_SIZE = 80 
VAL_BATCH_SIZE = 15
EPOCHS = 50
LEARNING_RATE = 0.0001 
MODEL_PATH = os.path.join("Diabetic_retinopathy_detection.h5")
MODEL_WEIGHTS_PATH=os.path.join("model_weights.h5")
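
# Augmentation (random shifts and horizontal flips) is applied only to the
# training generator; validation and test images are only rescaled.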

train_datagen = ImageDataGenerator(
        rescale=1./255,
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest')

train_generator = train_datagen.flow_from_directory(
        os.path.join(BASE_DATASET_FOLDER, TRAIN_FOLDER),
        target_size=IMAGE_SIZE,
        batch_size=TRAIN_BATCH_SIZE,
        class_mode='categorical',
        shuffle=True)

val_datagen = ImageDataGenerator(rescale=1./255)
val_generator = val_datagen.flow_from_directory(
        os.path.join(BASE_DATASET_FOLDER, VALIDATION_FOLDER),
        target_size=IMAGE_SIZE,
        batch_size=VAL_BATCH_SIZE,
        class_mode='categorical',
        shuffle=False)

test_datagen = ImageDataGenerator(rescale=1./255)
test_generator = test_datagen.flow_from_directory(
        os.path.join(BASE_DATASET_FOLDER, TEST_FOLDER),
        target_size=IMAGE_SIZE,
        batch_size=VAL_BATCH_SIZE,
        class_mode='categorical', 
        shuffle=False)

classes = {v: k for k, v in train_generator.class_indices.items()}
print(classes)
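
# Transfer learning: reuse the ImageNet-pretrained VGG16 convolutional base
# and train a small fully-connected classifier head on top of it.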

vgg_model = VGG16(weights='imagenet', include_top=False, input_shape=INPUT_SHAPE)

# Freeze all VGG16 layers except the last four, so only the top conv block is fine-tuned.
for layer in vgg_model.layers[:-4]:
    layer.trainable = False

model = Sequential()

model.add(vgg_model)

model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(len(classes), activation='softmax'))

model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.Adam(lr=LEARNING_RATE),
              metrics=['acc'])
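
# Train the classifier head (and the unfrozen VGG block) on the augmented training generator.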

history = model.fit_generator(
        train_generator,
        steps_per_epoch=train_generator.samples//train_generator.batch_size,
        epochs=EPOCHS,
        validation_data=val_generator,
        validation_steps=val_generator.samples//val_generator.batch_size)

model.save(MODEL_PATH)
model.save_weights(MODEL_WEIGHTS_PATH)
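
# Plot the accuracy and loss curves to compare training and validation behaviour.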

acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(acc))

plt.plot(epochs, acc, 'b', label='Training acc')
plt.plot(epochs, val_acc, 'r', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()

plt.figure()

plt.plot(epochs, loss, 'b', label='Training loss')
plt.plot(epochs, val_loss, 'r', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()

plt.show()

test_steps = int(np.ceil(test_generator.samples / test_generator.batch_size))
loss, accuracy = model.evaluate_generator(test_generator, steps=test_steps)

# Reset the iterator so the predictions line up with test_generator.classes.
test_generator.reset()
Y_pred = model.predict_generator(test_generator, verbose=1, steps=test_steps)
y_pred = np.argmax(Y_pred, axis=1)
cnf_matrix = confusion_matrix(test_generator.classes, y_pred)

def plot_confusion_matrix(cm, classes,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):

    cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    plt.figure(figsize=(12,12))

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title, fontsize=18)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45, fontsize=8)
    plt.yticks(tick_marks, classes, fontsize=12)

    fmt = '.2f'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center", 
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label', fontsize=16)
    plt.xlabel('Predicted label', fontsize=16)

plot_confusion_matrix(cnf_matrix, list(classes.values()))

print(classification_report(test_generator.classes, y_pred, target_names=list(classes.values())))

def load_image(filename):
    img = cv2.imread(os.path.join(BASE_DATASET_FOLDER, TEST_FOLDER, filename))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; the model was trained on RGB
    img = cv2.resize(img, (IMAGE_SIZE[0], IMAGE_SIZE[1]))
    img = img / 255.0                           # same rescaling as the generators

    return img


def predict(image):
    probabilities = model.predict(np.asarray([image]))[0]
    class_idx = np.argmax(probabilities)

    return {classes[class_idx]: probabilities[class_idx]}


for idx, filename in enumerate(random.sample(test_generator.filenames, 10)):
    print("SOURCE: class: %s, file: %s" % (os.path.split(filename)[0], filename))

    img = load_image(filename)   # load one test image
    prediction = predict(img)    # predict its class
    print("PREDICTED: class: %s, confidence: %f" % (list(prediction.keys())[0], list(prediction.values())[0]))
    plt.figure(idx)
    plt.imshow(img)
    plt.show()

0 Answers:

No answers yet