CNN预测不正确

时间:2019-12-05 10:33:16

标签: python tensorflow keras neural-network conv-neural-network

我尝试使用keras对德国交通标志进行分类。当数据集不平衡时,我获得了99%的val_accuracy。然后,我使用分类报告检查了F1分数:

print(classification_report(y_test1, y_pred_bool))

             precision    recall  f1-score   support

          0       1.00      0.96      0.98        54
          1       0.99      0.99      0.99       618
          2       0.98      1.00      0.99       499
          3       0.98      0.99      0.99       341
          4       1.00      0.98      0.99       472
          5       0.98      1.00      0.99       572
          6       0.99      1.00      0.99       176
          7       1.00      0.98      0.99       176
          8       1.00      1.00      1.00        98
          9       1.00      1.00      1.00       294
         10       0.99      1.00      1.00       313
         11       0.96      0.98      0.97        56
         12       0.99      0.99      0.99       553
         13       0.99      0.96      0.97        95
         14       1.00      1.00      1.00        87
         15       0.98      1.00      0.99       104
         16       0.99      1.00      0.99       133
         17       1.00      0.99      0.99        67
         18       0.98      1.00      0.99       345
         19       1.00      1.00      1.00       151
         20       1.00      1.00      1.00        50
         21       1.00      0.98      0.99       153
         22       0.97      0.99      0.98        73
         23       0.99      0.99      0.99       350
         24       0.99      0.96      0.98       104
         25       1.00      1.00      1.00       217
         26       1.00      1.00      1.00        54
         27       1.00      0.99      0.99       165
         28       0.95      0.98      0.96        93
         29       0.99      1.00      0.99       275
         30       0.98      1.00      0.99        95
         31       1.00      1.00      1.00        58
         32       1.00      0.99      1.00       535
         33       0.98      1.00      0.99        62
         34       0.99      1.00      0.99       497
         35       1.00      0.98      0.99       100
         36       1.00      1.00      1.00        65
         37       0.98      0.98      0.98        49
         38       1.00      0.94      0.97       446
         39       0.98      1.00      0.99        90
         40       0.97      1.00      0.99       368
         41       1.00      0.99      1.00       337
         42       1.00      0.99      1.00       363

   accuracy                           0.99      9803
  macro avg       0.99      0.99      0.99      9803
weighted avg       0.99      0.99      0.99      9803

然后,我检查了混淆矩阵,这是正确的。

我将模型保存在磁盘上,并再次加载它,以使用以下代码预测每个类的图像:

model = load_model('/kaggle/working/models-07-0.9904.h5')
pred = model.predict(images1)
print(pred)
y_pred_bool = np.argmax(pred, axis = 1)
print(y_pred_bool)

不幸的是,大多数预测都是错误的,我也不知道为什么。有什么建议吗?

编辑

具有多标签分类

我从每个标签上拍摄了前10张图像并进行了预测

对应的标签是

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42] 

y_pred_bool

array([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,  1,
    1,  1,  1, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 23, 23, 23, 23,
   23, 23, 23, 23, 23, 23, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 38,
   38, 38, 38, 38, 38, 38, 38, 38, 38, 39, 39, 39, 39, 39, 39, 39, 39,
   39, 39, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 41, 41, 41, 41, 41,
   41, 41, 41, 41, 41, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42,  2,  2,
    2,  2,  2,  2,  2,  2,  2,  2,  3,  3,  3,  3,  3,  3,  3,  3,  3,
    3,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  5,  5,  5,  5,  5,  5,
    5,  5,  5,  5,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  7,  7,  7,
    7,  7,  7,  7,  7,  7,  7,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
    9,  9,  9,  9,  9,  9,  9,  9,  9,  9, 10, 10, 10, 10, 10, 10, 10,
   10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 13, 13, 13, 13,
   13, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15,
   15, 15, 15, 15, 15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16,
   16, 16, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 18, 18, 18, 18, 18,
   18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 20, 20,
   20, 20, 20, 20, 20, 20, 20,  3, 21, 21, 21, 21, 21, 21, 21, 21, 21,
   21, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 24, 24, 24, 24, 24, 24,
   24, 24, 24, 24, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 26, 26, 26,
   26, 26, 26, 26, 26, 26, 26, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
   28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 29, 29, 29, 29, 29, 29, 29,
   29, 29, 29, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 31, 31, 31, 31,
   31, 31, 31, 31, 31, 31, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 33,
   33, 33, 33, 33, 33, 33, 33, 33, 33, 35, 35, 35, 35, 35, 35, 35, 35,
   35, 35, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 37, 37, 37, 37, 37,
   37, 37, 37, 37, 37])

1 个答案:

答案 0 :(得分:0)

如果要保存和加载模型并在同一数据集上求值,则在保存模型和加载模型之后结果不得更改

示例:我运行了一个简单的模型,并使用model.save保存,并使用keras的load_model进行了加载。您可以从here下载pima-indians-diabetes.data.csv数据集(如果无法在此处下载,则可以从kaggle和其他来源下载)。

构建,评估和保存模型-

%tensorflow_version 1.x
# MLP for Pima Indians Dataset saved to single file
import numpy as np
from numpy import loadtxt
from keras.models import Sequential
from keras.layers import Dense

# load pima indians dataset
dataset = np.loadtxt("/content/pima-indians-diabetes.csv", delimiter=",")

# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]

# define model
model = Sequential()
model.add(Dense(12, input_dim=8, activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(1, activation='sigmoid'))

# compile model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# Model Summary
model.summary()

# Fit the model
model.fit(X, Y, epochs=150, batch_size=10, verbose=0)

# evaluate the model
scores = model.evaluate(X, Y, verbose=0)
print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))

# save model and architecture to single file
model.save("model.h5")
print("Saved model to disk")

输出-

WARNING:tensorflow:From /tensorflow-1.15.2/python3.6/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
WARNING:tensorflow:From /tensorflow-1.15.2/python3.6/tensorflow_core/python/ops/nn_impl.py:183: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_1 (Dense)              (None, 12)                108       
_________________________________________________________________
dense_2 (Dense)              (None, 8)                 104       
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 9         
=================================================================
Total params: 221
Trainable params: 221
Non-trainable params: 0
_________________________________________________________________
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.

accuracy: 75.78%
Saved model to disk

加载和评估模型-

# load and evaluate a saved model
from numpy import loadtxt
from keras.models import load_model

# load model
model = load_model('model.h5')

# summarize model.
model.summary()

# load dataset
dataset = loadtxt("pima-indians-diabetes.csv", delimiter=",")

# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]

# evaluate the model
score = model.evaluate(X, Y, verbose=0)
print("%s: %.2f%%" % (model.metrics_names[1], score[1]*100))

输出-

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_1 (Dense)              (None, 12)                108       
_________________________________________________________________
dense_2 (Dense)              (None, 8)                 104       
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 9         
=================================================================
Total params: 221
Trainable params: 221
Non-trainable params: 0
_________________________________________________________________
accuracy: 75.78%