Transfer learning with VGG16 on the CIFAR-10 dataset: very high training and test accuracy, but wrong predictions

Asked: 2019-04-20 19:38:36

Tags: python image-processing deep-learning conv-neural-network transfer-learning

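I trained a VGG16 model on the CIFAR-10 dataset using transfer learning. After one epoch it reached about 89% training accuracy and also about 89% test accuracy. However, when I use the trained model to predict labels for images from outside the dataset, it gives wrong answers. It mislabels even very clear images.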

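I tried increasing the number of epochs to 20, which raised the training and test accuracy to around 93-94%, and I tried many different images. The trained model predicts images from the dataset correctly but has trouble with new images.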

#!/usr/bin/env python
# coding: utf-8

# In[1]:

import numpy as np
import pandas as pd
from tqdm import tqdm
from keras import models
from keras.applications.vgg16 import VGG16, preprocess_input
from keras.layers import (Activation, Conv2D, Dense, Dropout, Flatten,
                          GlobalAveragePooling2D, MaxPooling2D)
from keras.models import Model, Sequential, load_model
from keras.optimizers import Adam
from keras.utils import np_utils

np.random.seed(123)


# In[2]:


from keras.datasets import cifar10

(X_train, y_train), (X_test, y_test) = cifar10.load_data()


# In[3]:


print (X_train.shape)
print (X_test.shape)
#print (X_train[:2])

# In[4]:


from matplotlib import pyplot as plt


# In[5]:


X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255


# In[6]:


print (y_train.shape)
print (y_test[:10])



# In[7]:


Y_train = np_utils.to_categorical(y_train, 10)
Y_test = np_utils.to_categorical(y_test, 10)


# In[9]:


import cv2


# In[24]:


train_set_x= X_train[:500]

train_set_y= Y_train[:500]
test_set_x= X_test[:100]
test_set_y= Y_test[:100]

# In[33]:

plt.imshow(X_test[1])
plt.show()

#train_set_y.shape


# In[27]:


frozen = VGG16 (weights="imagenet", input_shape=(32,32,3), include_top=False)


# In[28]:


frozen.summary()


# In[36]:


trainable = frozen.output
trainable = GlobalAveragePooling2D()(trainable)
#print(trainable.shape)
trainable = Dense(128, activation="relu")(trainable)
trainable = Dense(32, activation="relu")(trainable)
trainable = Dense(10, activation="softmax")(trainable)


# In[37]:


model = Model(inputs=frozen.input, outputs=trainable)


# In[38]:


model.summary()


# In[16]:


model.layers


# In[18]:


for layer in model.layers[:-4]:
    layer.trainable = False


# In[19]:


for layer in model.layers:
    print(layer, layer.trainable)


# In[40]:


learning_rate = 0.0001
opt = Adam(lr=learning_rate)
model.compile(optimizer=opt,
              loss='binary_crossentropy',
              metrics=['accuracy'])


# In[41]:


def evaluate_this_model(model, epochs):

    np.random.seed(1)

    history = model.fit(train_set_x, train_set_y, epochs=epochs)
    results = model.evaluate(test_set_x, test_set_y)

    plt.plot(np.squeeze(history.history["loss"]))
    plt.ylabel('cost')
    plt.xlabel('iterations (per tens)')
    plt.title("Learning rate =" + str(learning_rate))
    plt.show()

    print("\n\nAccuracy on training set is {}".format(history.history["acc"][-1]))
    print("\nAccuracy on test set is {}".format(results[1]))


# In[42]:


train_set_x.shape


evaluate_this_model(model, 1)
model.save('vgg16.h5')

model1=load_model('vgg16.h5')




IMG_SIZE=32
path1='../input/ship.png'
img_data1 = cv2.imread(path1, cv2.IMREAD_COLOR)
img_data1 = cv2.resize(img_data1, (IMG_SIZE, IMG_SIZE))
data1 = img_data1.reshape(-1, IMG_SIZE, IMG_SIZE, 3)
model_out=model1.predict(data1)

# CIFAR-10 class names, indexed by label
labels = ['Airplane', 'Automobile', 'Bird', 'Cat', 'Deer',
          'Dog', 'Frog', 'Horse', 'Ship', 'Truck']
str_label = labels[np.argmax(model_out)]
print(str_label)

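Even after a single epoch, the trained model predicts and labels images from the dataset correctly, but with new images it gives completely wrong labels. For example, it labels a very clear image of a ship as a deer. The same happens with the other classes.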
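
A note on the accuracy figures quoted above, offered as a sketch rather than a diagnosis: the model is compiled with `loss='binary_crossentropy'` and `metrics=['accuracy']`, and in that combination Keras reports binary accuracy, which scores each of the 10 output units separately after thresholding at 0.5. On one-hot labels this can look high even when the predicted class is wrong. The NumPy-only example below uses made-up labels and a hypothetical predictor that always outputs 0.1 for every class; `binary_accuracy` and `categorical_accuracy` here are local re-implementations for illustration, not imports from Keras.

```python
import numpy as np

NUM_CLASSES = 10  # CIFAR-10 has 10 classes


def binary_accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of individual output units that match after thresholding."""
    return float(np.mean((y_pred > threshold) == (y_true > 0.5)))


def categorical_accuracy(y_true, y_pred):
    """Fraction of samples whose arg-max class equals the true class."""
    return float(np.mean(np.argmax(y_pred, axis=1) == np.argmax(y_true, axis=1)))


# Hypothetical data: 1000 random one-hot labels and a predictor that has
# learned nothing (it outputs 1/10 for every class of every sample).
rng = np.random.default_rng(0)
labels = rng.integers(0, NUM_CLASSES, size=1000)
y_true = np.eye(NUM_CLASSES)[labels]
y_pred = np.full_like(y_true, 1.0 / NUM_CLASSES)

print(binary_accuracy(y_true, y_pred))       # 0.9
print(categorical_accuracy(y_true, y_pred))  # close to chance level (about 0.1)
```

If this is what is happening here, compiling with `loss='categorical_crossentropy'` should make the reported accuracy the per-image figure, which is the number that matters for single-image predictions.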

1 answer:

Answer 0 (score: 1)

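It looks like you are scaling the pixel values of your training and test data by dividing by 255. I don't see that happening for ship.png. I suggest writing one function that does all of the preprocessing and running it for training, testing, and prediction, so that every image is prepared in exactly the same way.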
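
The suggestion above could be sketched as follows. The function names `preprocess` and `load_image_for_prediction` are hypothetical, chosen for illustration. The BGR-to-RGB conversion is an addition that the answer does not mention: OpenCV's `imread` returns channels in BGR order, while the CIFAR-10 arrays loaded by Keras are RGB, so it is assumed here that the channel order should be aligned as well.

```python
import numpy as np

IMG_SIZE = 32  # CIFAR-10 image side length, as in the question


def preprocess(images):
    """Scale uint8 RGB pixel data to float32 in [0, 1]; add a batch axis if needed."""
    x = np.asarray(images, dtype="float32") / 255.0
    if x.ndim == 3:  # a single (H, W, 3) image becomes (1, H, W, 3)
        x = x[np.newaxis, ...]
    return x


def load_image_for_prediction(path, size=IMG_SIZE):
    """Read an image file with OpenCV and pass it through preprocess()."""
    import cv2  # imported here so preprocess() also works without OpenCV

    img = cv2.imread(path, cv2.IMREAD_COLOR)
    if img is None:  # imread returns None when the file cannot be read
        raise FileNotFoundError(path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # assumption: match CIFAR-10's RGB order
    img = cv2.resize(img, (size, size))
    return preprocess(img)
```

With this in place, the training and test arrays would go through `preprocess(X_train)` and `preprocess(X_test)` instead of the in-place `/= 255`, and the prediction step would become `model1.predict(load_image_for_prediction('../input/ship.png'))`, so all three paths share the same scaling.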