I'm trying to classify German traffic signs with Keras. Even though the dataset is imbalanced, I got 99% val_accuracy. I then checked the F1 scores with a classification report:
from sklearn.metrics import classification_report

print(classification_report(y_test1, y_pred_bool))
precision recall f1-score support
0 1.00 0.96 0.98 54
1 0.99 0.99 0.99 618
2 0.98 1.00 0.99 499
3 0.98 0.99 0.99 341
4 1.00 0.98 0.99 472
5 0.98 1.00 0.99 572
6 0.99 1.00 0.99 176
7 1.00 0.98 0.99 176
8 1.00 1.00 1.00 98
9 1.00 1.00 1.00 294
10 0.99 1.00 1.00 313
11 0.96 0.98 0.97 56
12 0.99 0.99 0.99 553
13 0.99 0.96 0.97 95
14 1.00 1.00 1.00 87
15 0.98 1.00 0.99 104
16 0.99 1.00 0.99 133
17 1.00 0.99 0.99 67
18 0.98 1.00 0.99 345
19 1.00 1.00 1.00 151
20 1.00 1.00 1.00 50
21 1.00 0.98 0.99 153
22 0.97 0.99 0.98 73
23 0.99 0.99 0.99 350
24 0.99 0.96 0.98 104
25 1.00 1.00 1.00 217
26 1.00 1.00 1.00 54
27 1.00 0.99 0.99 165
28 0.95 0.98 0.96 93
29 0.99 1.00 0.99 275
30 0.98 1.00 0.99 95
31 1.00 1.00 1.00 58
32 1.00 0.99 1.00 535
33 0.98 1.00 0.99 62
34 0.99 1.00 0.99 497
35 1.00 0.98 0.99 100
36 1.00 1.00 1.00 65
37 0.98 0.98 0.98 49
38 1.00 0.94 0.97 446
39 0.98 1.00 0.99 90
40 0.97 1.00 0.99 368
41 1.00 0.99 1.00 337
42 1.00 0.99 1.00 363
accuracy 0.99 9803
macro avg 0.99 0.99 0.99 9803
weighted avg 0.99 0.99 0.99 9803
Then I checked the confusion matrix, which also looked correct.
I saved the model to disk and loaded it again to predict images from each class with the following code:
from keras.models import load_model
import numpy as np

# reload the saved checkpoint and predict on the per-class images
model = load_model('/kaggle/working/models-07-0.9904.h5')
pred = model.predict(images1)
print(pred)
y_pred_bool = np.argmax(pred, axis=1)
print(y_pred_bool)
Unfortunately, most of the predictions are wrong, and I don't know why. Any suggestions?
EDIT
With multi-label classification
I took the first 10 images from each label and ran predictions on them.
The corresponding labels are
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42]
y_pred_bool
array([ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 23, 23, 23, 23,
23, 23, 23, 23, 23, 23, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 38,
38, 38, 38, 38, 38, 38, 38, 38, 38, 39, 39, 39, 39, 39, 39, 39, 39,
39, 39, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 41, 41, 41, 41, 41,
41, 41, 41, 41, 41, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5,
5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7,
7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10,
10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 13, 13, 13, 13,
13, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15,
15, 15, 15, 15, 15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16,
16, 16, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 18, 18, 18, 18, 18,
18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 20, 20,
20, 20, 20, 20, 20, 20, 20, 3, 21, 21, 21, 21, 21, 21, 21, 21, 21,
21, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 24, 24, 24, 24, 24, 24,
24, 24, 24, 24, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 26, 26, 26,
26, 26, 26, 26, 26, 26, 26, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 29, 29, 29, 29, 29, 29, 29,
29, 29, 29, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 31, 31, 31, 31,
31, 31, 31, 31, 31, 31, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 33,
33, 33, 33, 33, 33, 33, 33, 33, 33, 35, 35, 35, 35, 35, 35, 35, 35,
35, 35, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 37, 37, 37, 37, 37,
37, 37, 37, 37, 37])
Answer 0: (score: 0)
If you save a model, load it back, and evaluate it on the same dataset, the results must not change between saving and loading. A quick way to verify this is sketched below.
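As a minimal sanity-check sketch (not the asker's code; `model`, `X` and the file name are placeholders), you can compare the predictions of the in-memory model with those of the reloaded copy on the same array. They should match up to floating-point tolerance, and for a multi-class model such as the traffic-sign classifier the argmax labels should match exactly:

import numpy as np
from keras.models import load_model

# `model` is assumed to be the trained in-memory Keras model and `X`
# the data it was evaluated on (both are hypothetical placeholders)
pred_before = model.predict(X)

# save, reload, and predict again on exactly the same data
model.save("roundtrip_check.h5")
reloaded = load_model("roundtrip_check.h5")
pred_after = reloaded.predict(X)

# raw outputs should agree up to float tolerance, labels exactly
print(np.allclose(pred_before, pred_after))                                      # expect True
print((np.argmax(pred_before, axis=1) == np.argmax(pred_after, axis=1)).all())   # expect True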
Example: I ran a simple model, saved it with model.save, and loaded it with Keras's load_model. You can download the pima-indians-diabetes.data.csv dataset from here (if you cannot download it there, it is also available on Kaggle and other sources).
Build, evaluate, and save the model -
%tensorflow_version 1.x
# MLP for Pima Indians Dataset saved to single file
import numpy as np
from numpy import loadtxt
from keras.models import Sequential
from keras.layers import Dense
# load pima indians dataset
dataset = np.loadtxt("/content/pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# define model
model = Sequential()
model.add(Dense(12, input_dim=8, activation='relu'))
model.add(Dense(8, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# compile model
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
# Model Summary
model.summary()
# Fit the model
model.fit(X, Y, epochs=150, batch_size=10, verbose=0)
# evaluate the model
scores = model.evaluate(X, Y, verbose=0)
print("%s: %.2f%%" % (model.metrics_names[1], scores[1]*100))
# save model and architecture to single file
model.save("model.h5")
print("Saved model to disk")
Output -
WARNING:tensorflow:From /tensorflow-1.15.2/python3.6/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
WARNING:tensorflow:From /tensorflow-1.15.2/python3.6/tensorflow_core/python/ops/nn_impl.py:183: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_1 (Dense) (None, 12) 108
_________________________________________________________________
dense_2 (Dense) (None, 8) 104
_________________________________________________________________
dense_3 (Dense) (None, 1) 9
=================================================================
Total params: 221
Trainable params: 221
Non-trainable params: 0
_________________________________________________________________
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.
accuracy: 75.78%
Saved model to disk
Load and evaluate the model -
# load and evaluate a saved model
from numpy import loadtxt
from keras.models import load_model
# load model
model = load_model('model.h5')
# summarize model.
model.summary()
# load dataset
dataset = loadtxt("pima-indians-diabetes.csv", delimiter=",")
# split into input (X) and output (Y) variables
X = dataset[:,0:8]
Y = dataset[:,8]
# evaluate the model
score = model.evaluate(X, Y, verbose=0)
print("%s: %.2f%%" % (model.metrics_names[1], score[1]*100))
Output -
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_1 (Dense) (None, 12) 108
_________________________________________________________________
dense_2 (Dense) (None, 8) 104
_________________________________________________________________
dense_3 (Dense) (None, 1) 9
=================================================================
Total params: 221
Trainable params: 221
Non-trainable params: 0
_________________________________________________________________
accuracy: 75.78%
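Applied to the question itself, a minimal sketch along the same lines (assuming a hypothetical x_test1 holding the exact test images that produced the 99% report; y_test1 and the checkpoint path are taken from the question): if the reloaded checkpoint still reproduces the reported accuracy on those arrays, saving/loading is not the problem, and the mismatch most likely comes from how images1 was prepared (size, scaling, channel order) compared to the training data.

import numpy as np
from keras.models import load_model

# reload the checkpoint from the question
model = load_model('/kaggle/working/models-07-0.9904.h5')

# x_test1 is assumed to be the exact array the classification report
# was computed on (hypothetical name, adjust to your own variables)
pred_test = model.predict(x_test1)
acc = np.mean(np.argmax(pred_test, axis=1) == y_test1)
print("accuracy on the original test arrays:", acc)   # should still be ~0.99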