在新类别和先前学过的类别上评估预训练模型

时间:2019-10-13 10:38:22

标签: tensorflow machine-learning keras

我正在尝试编写一个模型来演示灾难性遗忘(catastrophic forgetting)。我使用的是 CIFAR-10 图像数据集,其中包含 10 个不同类别的图像。首先,我在其中 8 个类别上训练了模型(不包括 car 和 cat 两个类别)。

现在,我想在每个训练周期(epoch)中只向预训练模型展示 cat 类的图像,以此演示灾难性遗忘,并观察学习新的 cat 类如何影响模型对旧类别的预测准确率。对于每个训练周期,我想计算旧类别上的训练误差和预测误差。

我是深度学习的新手,整体逻辑流程上可能有遗漏。下面是我正在使用的脚本:

import os
import time

from sklearn.metrics import classification_report, confusion_matrix
from keras_sequential_ascii import sequential_model_to_ascii_printout
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import LearningRateScheduler
from keras.models import load_model
from keras.models import Sequential
from keras.datasets import cifar10
from keras.layers import Dense
from keras.utils import np_utils
import matplotlib.pyplot as plt
import seaborn as sn
import numpy as np
import pandas as pd
import keras

# load the entire dataset
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# CIFAR-10 label ids that this experiment treats specially.
CAR_LABEL = 1  # excluded from every split
CAT_LABEL = 3  # the class learned after pretraining


def _subset(x, y, mask):
    """Return the rows of images ``x`` and labels ``y`` where ``mask`` is True.

    Boolean indexing returns copies, so later in-place edits of one subset's
    labels cannot leak into another subset.
    """
    return x[mask], y[mask]


# Labels are stored as a 2-D column (n, 1); build the row masks from column 0.
train_labels = y_train[:, 0]
test_labels = y_test[:, 0]
cat_and_car = (CAR_LABEL, CAT_LABEL)

# everything except cars (all 9 classes of this experiment, for evaluation)
x_train_main, y_train_main = _subset(x_train, y_train, train_labels != CAR_LABEL)
x_test_main, y_test_main = _subset(x_test, y_test, test_labels != CAR_LABEL)

# just cats (focused learning)
x_train_cat, y_train_cat = _subset(x_train, y_train, train_labels == CAT_LABEL)
x_test_cat, y_test_cat = _subset(x_test, y_test, test_labels == CAT_LABEL)

# everything except cats and cars (the 8 classes of the pretraining stage)
x_train_minuscatcar, y_train_minuscatcar = _subset(
    x_train, y_train, ~np.isin(train_labels, cat_and_car))
x_test_minuscatcar, y_test_minuscatcar = _subset(
    x_test, y_test, ~np.isin(test_labels, cat_and_car))

# update the y label to account for deletion of car images.
# One table maps an original CIFAR-10 id to the label used in this experiment.
# Car has no slot. The 8 pretraining classes keep their relative order as 0..7,
# so labels 0..7 mean the same thing in every split. Cat is appended as class 8.
#   original id: 0   1   2   3   4   5   6   7   8   9
#   new label:   0   -   1   8   2   3   4   5   6   7
LABEL_REMAP = np.array([0, -1, 1, 8, 2, 3, 4, 5, 6, 7])


def _remap_labels(y):
    """Return a copy of ``y`` with every CIFAR-10 id translated via ``LABEL_REMAP``.

    Raises:
        ValueError: if ``y`` still contains a car label, which has no slot in
            the 9-class scheme.
    """
    remapped = LABEL_REMAP[y]
    if np.any(remapped < 0):
        raise ValueError("car labels must be removed before remapping")
    return remapped


y_train_main = _remap_labels(y_train_main)
y_test_main = _remap_labels(y_test_main)
y_train_cat = _remap_labels(y_train_cat)
y_test_cat = _remap_labels(y_test_cat)
y_train_minuscatcar = _remap_labels(y_train_minuscatcar)
y_test_minuscatcar = _remap_labels(y_test_minuscatcar)

# number of classes: the 8 pretraining classes plus cat
num_classes = 9
y_train_main = np_utils.to_categorical(y_train_main, num_classes)
y_test_main = np_utils.to_categorical(y_test_main, num_classes)
y_train_cat = np_utils.to_categorical(y_train_cat, num_classes)
y_test_cat = np_utils.to_categorical(y_test_cat, num_classes)
y_train_minuscatcar = np_utils.to_categorical(y_train_minuscatcar, num_classes)
y_test_minuscatcar = np_utils.to_categorical(y_test_minuscatcar, num_classes)


# pre processing: cast every split to float32
x_train_main = x_train_main.astype('float32')
x_train_minuscatcar = x_train_minuscatcar.astype('float32')
x_train_cat = x_train_cat.astype('float32')
x_test_main = x_test_main.astype('float32')
x_test_minuscatcar = x_test_minuscatcar.astype('float32')
x_test_cat = x_test_cat.astype('float32')

# Standardise every split with ONE set of statistics. Per-split statistics
# (e.g. cat-only mean/std for the cat images) would put the training inputs
# and the evaluation inputs on different scales. That would mix a
# preprocessing shift into the forgetting measurement.
# NOTE(review): these statistics should equal the ones used when
# model_6cnn_pretrained.h5 was trained. Taking them from the 8 pretraining
# classes is an assumption; confirm against the pretraining script.
mean = np.mean(x_train_minuscatcar, axis=(0, 1, 2, 3))
std = np.std(x_train_minuscatcar, axis=(0, 1, 2, 3))
scale = std + 1e-7  # epsilon guards against division by zero

x_train_main = (x_train_main - mean) / scale
x_test_main = (x_test_main - mean) / scale
x_train_minuscatcar = (x_train_minuscatcar - mean) / scale
x_test_minuscatcar = (x_test_minuscatcar - mean) / scale
x_train_cat = (x_train_cat - mean) / scale
x_test_cat = (x_test_cat - mean) / scale



# callback to save class by class data
class MetricperClass(keras.callbacks.Callback):
    """Keras callback that stores a per-class classification report each epoch.

    Args:
        x: inputs passed to ``self.model.predict`` at the end of every epoch.
        y: matching targets. 1-D labels or a single label column are kept
            as-is; one-hot rows are collapsed to class ids with ``argmax``.

    After training, ``self.reports`` holds one sklearn ``classification_report``
    dict per epoch.
    """

    def __init__(self, x, y):
        # Let the Keras base class initialise its own attributes (model, etc.).
        super().__init__()
        self.x = x
        self.y = y if (y.ndim == 1 or y.shape[1] == 1) else np.argmax(y, axis=1)
        self.reports = []

    def on_epoch_end(self, epoch, logs=None):
        """Predict on ``self.x`` and append the resulting report dict."""
        y_hat = np.asarray(self.model.predict(self.x))
        # One output column is thresholded at 0.5; otherwise take the arg-max class.
        if y_hat.shape[1] == 1:
            y_hat = np.where(y_hat > 0.5, 1, 0)
        else:
            y_hat = np.argmax(y_hat, axis=1)
        report = classification_report(self.y, y_hat, output_dict=True)
        self.reports.append(report)

    # Utility method
    def get(self, metrics, of_class):
        """Return metric ``metrics`` for class ``of_class`` from every stored epoch."""
        return [report[str(of_class)][metrics] for report in self.reports]

#### learning rate scheduler
def lr_schedule(epoch):
    """Return the learning rate for ``epoch`` (used by ``LearningRateScheduler``).

    Piecewise constant: 0.001 through epoch 75, 0.0005 for epochs 76-100,
    and 0.0003 afterwards.
    """
    if epoch > 100:
        return 0.0003
    if epoch > 75:
        return 0.0005
    return 0.001

# load previously learned model
premodel = load_model('model_6cnn_pretrained.h5')
old_head = premodel.layers[-1]

# Re-use every pretrained layer except the old classification head.
feature_extractor = Sequential()
for layer in premodel.layers[:-1]:
    feature_extractor.add(layer)

# New head with one output per class of this experiment (cat = class 8).
new_head = Dense(num_classes, activation='softmax')
model = Sequential([feature_extractor, new_head])

# Seed the new head with the pretrained head's weights. The old classes then
# start from what the model already learned; only the extra (cat) outputs keep
# their random initialisation. Without this, old-class accuracy would be at
# chance before any cat image is seen, so there would be nothing to "forget".
# The old head's output columns are expected to follow the pretraining labels
# 0..7, which this script maps to the same labels 0..7.
old_weights = old_head.get_weights()
kernel, bias = new_head.get_weights()
if (len(old_weights) == 2
        and old_weights[0].shape[0] == kernel.shape[0]
        and old_weights[0].shape[1] <= num_classes):
    n_old = old_weights[0].shape[1]
    kernel[:, :n_old] = old_weights[0]
    bias[:n_old] = old_weights[1]
    new_head.set_weights([kernel, bias])
else:
    print("warning: pretrained head is not a compatible Dense layer; "
          "the new head keeps its random initialisation")

# Freeze the pretrained feature extractor so only the new head is trained.
# NOTE(review): freezing confines forgetting to the output layer; make
# feature_extractor trainable to also study forgetting in the conv features.
for layer in model.layers[:-1]:
    layer.trainable = False
# summarize model.
model.summary()

# data augmentation
datagen = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    )
datagen.fit(x_train_cat)  # FOCUSED

# training begins
batch_size = 64
# The LearningRateScheduler callback below resets the rate at every epoch.
opt_rms = keras.optimizers.RMSprop(lr=0.001, decay=1e-6)
# compile the model
model.compile(loss='categorical_crossentropy', optimizer=opt_rms,
              metrics=['accuracy', 'mean_squared_error'])
# callback recording per-class reports on the 9-class test set after each epoch
metrics_multiclass = MetricperClass(x_test_main, y_test_main)

# Results are written to focused/; create it before training so a missing
# directory does not throw away a finished run.
os.makedirs('focused', exist_ok=True)

# One epoch = one pass over the cat images the generator actually streams
# (not over the much larger 9-class training set).
steps_per_epoch = max(1, x_train_cat.shape[0] // batch_size)

start_time = time.time()
history = model.fit_generator(
    datagen.flow(x_train_cat, y_train_cat, batch_size=batch_size),
    steps_per_epoch=steps_per_epoch,
    epochs=5,
    verbose=1,
    validation_data=(x_test_main, y_test_main),
    callbacks=[LearningRateScheduler(lr_schedule), metrics_multiclass],
)
# note the end run time and save the elapsed seconds
end_time = time.time()
elapsed = end_time - start_time
print("elapsed time (s): ", str(elapsed))
np.save('focused/runtime_6cnn_focused.npy', elapsed)

# Persist the focused-training artefacts under focused/: the weights alone,
# the full model, and the per-epoch per-class reports from the callback.
weights_path = 'focused/model_weights_6cnn_focused.h5'
model_path = "focused/model_6cnn_focused.h5"
reports_path = 'focused/metricsperclass_6cnn_focused.npy'
model.save_weights(weights_path)
model.save(model_path)
# The reports are dicts, so they are stored as a pickled object array;
# reading them back needs np.load(..., allow_pickle=True).
np.save(reports_path, metrics_multiclass.reports)
# Print an ASCII diagram of the layer stack.
sequential_model_to_ascii_printout(model)

# Training history as CSV, one row per epoch.
hist_df = pd.DataFrame(history.history)
hist_csv_file = 'focused/history_6cnn_focused.csv'
with open(hist_csv_file, mode='w') as csv_out:
    hist_df.to_csv(csv_out)

# evaluate the model's performance on the 9-class test set (old classes + cat)
scores = model.evaluate(x_test_main, y_test_main, batch_size=128, verbose=1)
# scores follows the compiled loss/metrics order: [loss, accuracy, mse]
print('\nTest result: %.3f loss: %.3f' % (scores[1]*100, scores[0]))

# evaluate the model's prediction and save it
Y_pred = model.predict(x_test_main, verbose=2)
np.save('focused/predicted_y_6cnn_focused.npy', Y_pred)
y_pred = np.argmax(Y_pred, axis=1)
y_true = np.argmax(y_test_main, axis=1)

# Pin the label set so the matrix is always num_classes x num_classes, which
# the DataFrame below assumes, even if a class is absent from both y_true
# and y_pred.
cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))
np.save('focused/confusion_matrix_6cnn_focused.npy', cm)

# Visualizing confusion matrix: rows are true labels, columns are predictions
df_cm = pd.DataFrame(cm, range(num_classes), range(num_classes))
plt.figure(figsize=(14, 10))
sn.set(font_scale=1.4)  # for label size
# fmt='d' prints raw counts; the default '.2g' would render 1000 as 1e+03
sn.heatmap(df_cm, annot=True, annot_kws={"size": 12}, fmt='d')
plt.xlabel('predicted label')
plt.ylabel('true label')
plt.show()

0 个答案:

没有答案