用混淆矩阵评估我的模型

时间:2021-06-10 12:14:54

标签: python tensorflow

我正在尝试复制 Lenet-5 神经网络,并且我想显示我的结果的混淆矩阵来评估我的结果。

这就是我所做的:

# Create the model
model = models.Sequential()
model.add(layers.Conv2D(filters=6, kernel_size=(5,5), activation='relu', input_shape=(28,28,1)))
model.add(layers.MaxPooling2D(pool_size=(2,2)))
model.add(layers.Conv2D(filters=16, kernel_size=(5,5), activation='relu'))
model.add(layers.MaxPooling2D(pool_size=(2,2)))
model.add(layers.Flatten())
model.add(layers.Dense(120,activation='relu'))
model.add(layers.Dense(84,activation='relu'))
model.add(layers.Dense(10,activation='softmax'))

# I categorize the data because I use categorical crossentropy
train_labels = to_categorical(train_labels)
val_labels = to_categorical(val_labels)
test_labels = to_categorical(test_labels)
# Compile the model
model.compile(optimizer=SGD(learning_rate=0.1),
               loss='categorical_crossentropy',
               metrics=['accuracy'])
# Fit the model
history = model.fit(train_images, train_labels,
                    epochs=10, batch_size=128,
                    validation_data=(val_images, val_labels),
                    verbose=2)

从这里开始,我认为(希望)一切都好。不,我想评估我的模型的性能。

首先我绘制准确度。

plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
plt.xlabel('Epochs'); plt.ylabel('Accuracy')
plt.ylim([0.85, 1])
plt.legend(loc='best')

enter image description here

然后我评估准确率和损失。

test_loss, test_acc = model.evaluate(test_images,  test_labels, verbose=2)
>>> OUT: test_acc: 0.9909999966621399, test_loss: 0.03354883939027786

现在我想看看混淆矩阵,

from sklearn.metrics import confusion_matrix

predictions = model.predict(test_images)
confusion = confusion_matrix(test_labels, predictions.round())

但是我有这个错误:

<块引用>

ValueError: 不支持多标签指示器

我虽然问题是分类测试数据,但不是。有人可以帮助我吗?我现在的目标是尽可能最好地评估我的模型(我是新手),我认为混淆矩阵是个好主意。

非常感谢!

1 个答案:

答案 0 :(得分:2)

model.predict 返回模型输出的向量表示(例如 [0.1, 0.05, 0.0, 0.85],但 confusion_matrix 需要输出的标签/类别(例如 3)。

您可以使用 np.argmax 函数从向量获取预测标签:

predictedLabels = np.argmax(predictions, axis=1)

如果 test_labels 是单热编码的,您可能必须使用相同的方法。

P.S.:查看 ConfusionMatrixDisplay 以很好地展示混淆矩阵