深度学习模型不适用于新型数据

时间:2020-11-04 15:13:57

标签: python keras deep-learning cnn quickdraw

我从Google的快速绘制npy数据中训练了该模型。以下是我如何预处理数据以及如何训练模型。该测试集的准确性为50%,但考虑到它对345个类别进行分类的效果还是不错的。

# Reshape and normalize
x_train = x_train.reshape(x_train.shape[0], image_size, image_size, 1).astype('float32')
x_test = x_test.reshape(x_test.shape[0], image_size, image_size, 1).astype('float32')
#image_size is 28

x_train /= 255.0
x_test /= 255.0

# Convert class vectors to class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)


def cnn_model():
    # create model
    model = Sequential()
    model.add(Conv2D(30, (5, 5), input_shape=x_train.shape[1:], activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(15, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.2))
    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dense(50, activation='relu'))
    model.add(Dense(num_classes, activation='softmax'))
    # Compile model
    
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model

培训过程和评估结果如下。

Epoch 1/100
22356/22356 [==============================] - 1323s 59ms/step - loss: 2.7714 - accuracy: 0.3795 - val_loss: 2.2759 - val_accuracy: 0.4751
Epoch 2/100
22356/22356 [==============================] - 1339s 60ms/step - loss: 2.3925 - accuracy: 0.4481 - val_loss: 2.1659 - val_accuracy: 0.4948
Epoch 3/100
22356/22356 [==============================] - 1323s 59ms/step - loss: 2.3365 - accuracy: 0.4588 - val_loss: 2.1333 - val_accuracy: 0.5015
Epoch 4/100
22356/22356 [==============================] - 1303s 58ms/step - loss: 2.3131 - accuracy: 0.4630 - val_loss: 2.1396 - val_accuracy: 0.4996
Epoch 5/100
22356/22356 [==============================] - 1262s 56ms/step - loss: 2.3013 - accuracy: 0.4655 - val_loss: 2.1199 - val_accuracy: 0.5026
Epoch 6/100
22356/22356 [==============================] - 1326s 59ms/step - loss: 2.2932 - accuracy: 0.4663 - val_loss: 2.1190 - val_accuracy: 0.5046
Epoch 7/100
22356/22356 [==============================] - 1269s 57ms/step - loss: 2.2870 - accuracy: 0.4676 - val_loss: 2.1067 - val_accuracy: 0.5053
Epoch 8/100
22356/22356 [==============================] - 1299s 58ms/step - loss: 2.2844 - accuracy: 0.4678 - val_loss: 2.1090 - val_accuracy: 0.5053
Epoch 9/100
22356/22356 [==============================] - 1288s 58ms/step - loss: 2.2828 - accuracy: 0.4683 - val_loss: 2.1147 - val_accuracy: 0.5045
Epoch 10/100
22356/22356 [==============================] - 1289s 58ms/step - loss: 2.2797 - accuracy: 0.4683 - val_loss: 2.0907 - val_accuracy: 0.5073
Epoch 11/100
22356/22356 [==============================] - 1280s 57ms/step - loss: 2.2784 - accuracy: 0.4690 - val_loss: 2.1087 - val_accuracy: 0.5058
Epoch 12/100
22356/22356 [==============================] - 1262s 56ms/step - loss: 2.2787 - accuracy: 0.4688 - val_loss: 2.1078 - val_accuracy: 0.5035
Epoch 13/100
22356/22356 [==============================] - 1335s 60ms/step - loss: 2.2773 - accuracy: 0.4690 - val_loss: 2.1078 - val_accuracy: 0.5049
Epoch 14/100
22356/22356 [==============================] - 1292s 58ms/step - loss: 2.2789 - accuracy: 0.4687 - val_loss: 2.1239 - val_accuracy: 0.5014
Epoch 15/100
22356/22356 [==============================] - 1277s 57ms/step - loss: 2.2824 - accuracy: 0.4676 - val_loss: 2.1220 - val_accuracy: 0.5016
Epoch 16/100
22356/22356 [==============================] - 1291s 58ms/step - loss: 2.2816 - accuracy: 0.4682 - val_loss: 2.1093 - val_accuracy: 0.5058
CPU times: user 18h 13min 31s, sys: 4h 19min 8s, total: 22h 32min 40s
Wall time: 5h 46min 14s

19407/19407 [==============================] - 101s 5ms/step - loss: 2.1135 - accuracy: 0.5047
Test accuarcy: 50.47%

当我加载由Google快速绘图提供的npy数据时,该预测在我的深度学习模型上工作正常。

data_url = '/content/gdrive/My Drive/Colab Notebooks/img/numpy_bitmap/sun.npy'
example_cat = np.load(data_url)

cat_len = example_cat.shape[0] # number of total image

start_num = 11 

example = example_cat[start_num,:784+start_num]

plt.imshow(example.reshape(28, 28))
example = example.reshape(28,28,1).astype('float32')
example /=255.0
print(example)


import matplotlib.pyplot as plt
from random import randint
%matplotlib inline  

pred = model.predict(np.expand_dims(example, axis=0))[0]
ind = (-pred).argsort()[:5]
print(ind)
latex = [categories_dict[x] for x in ind]
plt.imshow(example.squeeze()) 
print(latex)

这是上面代码的结果,很有意义。 (以下链接显示了确切的结果) enter image description here [“太阳”,“蜘蛛”,“螃蟹”,“海龟”,“动物迁徙”]

然后,我捕获了完全相同的图像并将其另存为png图像。我再次将文件作为NumPy数组加载并进行了预处理,以便可以将其放入模型中以预测该文件属于哪个类别。而且它以某种方式不起作用并且返回完全不同的预测。我尝试使用的每个新png图像都在发生这种情况。

im = cv2.imread('/content/gdrive/My Drive/Colab Notebooks/sun2.PNG', cv2.IMREAD_GRAYSCALE)
resize_img = cv2.resize(im, (28,28), interpolation = cv2.INTER_AREA) 
img_vector = np.asarray(resize_img, dtype="uint8")
img = img_vector.reshape(28,28,1).astype('float32')

import matplotlib.pyplot as plt
from random import randint
%matplotlib inline  

img /= 255.0
pred = model.predict(np.expand_dims(img, axis=0))[0]
ind = (-pred).argsort()[:5]
print(ind)
latex = [categories_dict[x] for x in ind]
plt.imshow(img.squeeze()) 
print(latex)

现在我得到的结果完全相同。 enter image description here [“星”,“扩音器”,“蜘蛛”,“海豚”,“蚊子”]

我假设我对数据的预处理方式有问题,但是我找不到发生这种情况的原因以及我做错了什么。我已经非常努力地解决了这个问题,并四处搜寻,但是没有找到线索。如果有人能帮助我解决这个问题,我将不胜感激。

0 个答案:

没有答案