Create a scikit-learn confusion matrix based on a Keras CNN

Date: 2021-02-16 10:25:50

Tags: python tensorflow machine-learning keras scikit-learn

I trained my model on a dataset containing 15 different kinds of fruit. Everything works fine and the accuracy is usually above 90%.

But when I create a confusion matrix for my CNN, the values make no sense: they do not reflect the high accuracy at all, and I cannot figure out why, even though I have already tried several different implementations.

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential 
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D, Dropout, BatchNormalization
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.optimizers import SGD
from sklearn.metrics import classification_report, confusion_matrix

train_ds = image_dataset_from_directory(
    directory='training_data/',
    labels='inferred',
    label_mode='categorical',
    batch_size=32,
    image_size=(256, 256))

validation_ds = image_dataset_from_directory(
    directory='validation_data/',
    labels='inferred',
    label_mode='categorical',
    batch_size=32,
    image_size=(256, 256))
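
#note: with labels='inferred' the class labels are taken from the sub-directory names
#in alphanumerical order; train_ds.class_names / validation_ds.class_names hold the
#index-to-class mapping used for the confusion matrix rows/columns below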

#create model
model = Sequential()

#add 2 convolutional layers
model.add(Conv2D(32, (3, 3), activation="relu", padding="same", input_shape=(256, 256, 3), name = "conv1"))
model.add(Conv2D(32, (3, 3), activation="relu", padding="same", name = "conv2"))

#add the pooling layer
model.add(MaxPooling2D(pool_size=(2,2), name="pooling1"))


#add 2 convolutional layers
model.add(Conv2D(32, (3, 3), activation="relu", padding="same", name = "conv3"))
model.add(Conv2D(32, (3, 3), activation="relu", padding="same", name = "conv4"))

#add the pooling layer
model.add(MaxPooling2D(pool_size=(2,2), name="pooling2"))

#add the flattening layer
model.add(Flatten())

#add the fully connected layer
model.add(Dense(256, activation="relu", name="fc1"))

#add dropout layer
model.add(Dropout(0.2, name="dropout2"))

model.add(Dense(15, activation="softmax")) 
model.compile(loss="categorical_crossentropy", optimizer=SGD(learning_rate=0.00001, momentum=0.9), metrics=["accuracy"])

model.fit(train_ds, epochs=10, verbose=1)

score = model.evaluate(validation_ds, verbose=0)
print(f'Test loss: {score[0]} / Test accuracy: {score[1]}')

#gather all validation images in one pass, predict on them,
#then gather the labels in a second pass over the dataset
x_val = tf.concat([x for x, y in validation_ds], axis=0)
y_pred = model.predict(x_val)
predicted_categories = tf.argmax(y_pred, axis=1)

y_true = tf.concat([y for x, y in validation_ds], axis=0)
true_categories = tf.argmax(y_true, axis=1)

print(confusion_matrix(true_categories, predicted_categories))
print(classification_report(true_categories, predicted_categories))

Here you can see the values produced by the current implementation. The accuracy is only about 80% here because I reduced the number of epochs for this test:

[[ 9  3 10  2  5  5  8 19  1 12  5 11  1  6 11]
 [ 9  1  8  5  5  6  7 10  3 15  5 12  3  8  7]
 [11  2 10  5  2  5  7  6  2 15  8 11  9  7  8]
 [13  2  7  5  4  5 10 11  1 11  8  7  4  8  6]
 [ 9  3  7  6  4  9  5 13  3 16  4 11  3  7  3]
 [12  4  6  2  2  4 10 18  2 11  8 10  1 11  7]
 [10  2  8  5  3  9  4  8  4 11  7 11  8 11  7]
 [12  1  9  6  5  5  8 16  0 15  6  8  5  8 12]
 [ 9  2  6  5  2  9  9  8  5  8  3 15  8 12  7]
 [10  2  4  6  5 11  9 12  6  6  7  6  5  6  9]
 [ 1  5  5  9  1  7 13 11  4 11  9 10  7  7  8]
 [10  1  9  1  3  8  8  7  2 13  5  8 11  9 13]
 [ 8  1  8  3  5  6  9 11  5 12  8  8  4  7  5]
 [ 8  2  5  4  3  8 11  7  5 11 17  8  5 12  6]
 [13  2  7  6  3 12  9 12  1 11  7  7  6  4  8]]
              precision    recall  f1-score   support

           0       0.08      0.06      0.07       144
           1       0.01      0.03      0.01        33
           2       0.09      0.09      0.09       109
           3       0.05      0.07      0.06        70
           4       0.04      0.08      0.05        52
           5       0.04      0.04      0.04       109
           6       0.04      0.03      0.03       127
           7       0.14      0.09      0.11       169
           8       0.05      0.11      0.07        44
           9       0.06      0.03      0.04       178
          10       0.08      0.08      0.08       107
          11       0.07      0.06      0.06       143
          12       0.04      0.05      0.04        80
          13       0.11      0.10      0.10       123
          14       0.07      0.07      0.07       117

    accuracy                           0.07      1605
   macro avg       0.06      0.07      0.06      1605
weighted avg       0.07      0.07      0.07      1605

Test loss: 0.9647821187973022 / Test accuracy: 0.813707172870636

Current values
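
For reference, here is a single-pass way to collect the predictions and the true labels (a rough sketch reusing the model and validation_ds defined above, not a verified fix). Note that image_dataset_from_directory shuffles by default, so iterating the dataset twice, once for the images and once for the labels as in the code above, is not guaranteed to see the batches in the same order:

#sketch: collect predictions and true labels from the same pass over the dataset
all_preds = []
all_labels = []
for images, labels in validation_ds:
    probs = model.predict(images, verbose=0)
    all_preds.append(np.argmax(probs, axis=1))
    all_labels.append(np.argmax(labels.numpy(), axis=1))

predicted_categories = np.concatenate(all_preds)
true_categories = np.concatenate(all_labels)
print(confusion_matrix(true_categories, predicted_categories))

#quick consistency check: the accuracy implied by the confusion matrix,
#which should be close to the accuracy reported by model.evaluate
cm = confusion_matrix(true_categories, predicted_categories)
print(np.trace(cm) / cm.sum())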

0 Answers
