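I am looking for someone who can help me plot a confusion matrix. I need it for a term paper at university, but I have very little programming experience.
In the pictures you can see the classification report and the structure of my y_test and X_test (in my case, dtree_predictions).
I would be glad if someone could help, because I have tried many things and I get no solution, only error messages.
Below are my code and the classification report, followed by the multilabel confusion matrix: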
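from sklearn import metrics
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
X_train, X_test, y_train, y_test = train_test_split(X, Y_profile, test_size=0.3, random_state=30)
dtree_model = DecisionTreeClassifier().fit(X_train, y_train)
dtree_predictions = dtree_model.predict(X_test)
# Note: classification_report expects (y_true, y_pred); with the order below, precision and recall are swapped.
print(metrics.classification_report(dtree_predictions, y_test))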
              precision    recall  f1-score   support

           0       1.00      1.00      1.00       222
           1       1.00      1.00      1.00       211
           2       1.00      1.00      1.00       229
           3       0.96      0.97      0.96       348
           4       0.89      0.85      0.87        93
           5       0.86      0.86      0.86       105
           6       0.94      0.93      0.94       116
           7       1.00      1.00      1.00       364
           8       0.99      0.97      0.98       139
           9       0.98      0.99      0.99       159
          10       0.97      0.96      0.97       189
          11       0.92      0.92      0.92       124
          12       0.92      0.92      0.92       119
          13       0.95      0.96      0.95       230
          14       0.98      0.96      0.97       452
          15       0.91      0.96      0.93       210

   micro avg       0.96      0.96      0.96      3310
   macro avg       0.95      0.95      0.95      3310
weighted avg       0.97      0.96      0.96      3310
 samples avg       0.96      0.96      0.96      3310
Multilabel confusion matrix:
from sklearn.metrics import multilabel_confusion_matrix
multilabel_confusion_matrix(y_test, dtree_predictions)
array([[[440, 0],
[ 0, 222]],
[[451, 0],
[ 0, 211]],
[[433, 0],
[ 0, 229]],
[[299, 10],
[ 15, 338]],
[[559, 14],
[ 10, 79]],
[[542, 15],
[ 15, 90]],
[[539, 8],
[ 7, 108]],
[[297, 0],
[ 1, 364]],
[[522, 4],
[ 1, 135]],
[[500, 1],
[ 3, 158]],
[[468, 8],
[ 5, 181]],
[[528, 10],
[ 10, 114]],
[[534, 9],
[ 9, 110]],
[[420, 9],
[ 12, 221]],
[[201, 19],
[ 9, 433]],
[[433, 9],
[ 19, 201]]])
[Images omitted: structure of y_test and dtree_predictions]
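Each 2x2 block returned by multilabel_confusion_matrix is laid out as [[TN, FP], [FN, TP]] for one label, so per-label precision and recall can be read off directly. A minimal sketch, using the block for class 3 from the output above (the variable names are illustrative and not from the original post):

```python
import numpy as np

# 2x2 block for class 3, copied from the multilabel_confusion_matrix output above.
# scikit-learn lays each block out as [[TN, FP], [FN, TP]].
cm_class3 = np.array([[299, 10],
                      [15, 338]])

tn, fp, fn, tp = cm_class3.ravel()

precision = tp / (tp + fp)   # fraction of predicted positives that are correct
recall = tp / (tp + fn)      # fraction of actual positives that were found
support = tp + fn            # number of samples that truly carry this label

print(f"precision={precision:.2f} recall={recall:.2f} support={support}")
```

This gives a precision of 0.97, a recall of 0.96 and a support of 353. The report lists 0.96, 0.97 and 348 for class 3, which is the mirror image one would expect when classification_report is called with the predictions as its first argument instead of the true labels.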
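Answer 0 (score: 3)
Usually, a confusion matrix is visualized as a heatmap. There is also a function on GitHub for pretty-printing a confusion matrix. Inspired by it, I adapted it to the multilabel case: each class gets its own binary (Y/N) matrix, and each matrix is drawn as a heatmap.
Here is an example that takes some of the output from the posted code: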
import numpy as np
vis_arr = np.asarray([[[440, 0],
[ 0, 222]],
[[451, 0],
[ 0, 211]],
[[433, 0],
[ 0, 229]],
[[299, 10],
[ 15, 338]],
[[559, 14],
[ 10, 79]],
[[542, 15],
[ 15, 90]],
[[539, 8],
[ 7, 108]],
[[297, 0],
[ 1, 364]],
[[522, 4],
[ 1, 135]],
[[500, 1],
[ 3, 158]],
[[468, 8],
[ 5, 181]],
[[528, 10],
[ 10, 114]],
[[534, 9],
[ 9, 110]],
[[420, 9],
[ 12, 221]],
[[201, 19],
[ 9, 433]],
[[433, 9],
[ 19, 201]]])
labels = [f"c{i}" for i in range(16)]
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
def print_confusion_matrix(confusion_matrix, axes, class_label, class_names, fontsize=14):
    # Wrap the 2x2 array in a DataFrame so the heatmap gets named ticks.
    df_cm = pd.DataFrame(
        confusion_matrix, index=class_names, columns=class_names,
    )
    try:
        heatmap = sns.heatmap(df_cm, annot=True, fmt="d", cbar=False, ax=axes)
    except ValueError:
        raise ValueError("Confusion matrix values must be integers.")
    heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=fontsize)
    heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=fontsize)
    # scikit-learn puts true labels on rows (y axis) and predictions on columns (x axis).
    axes.set_ylabel('True label')
    axes.set_xlabel('Predicted label')
    axes.set_title("Confusion Matrix for the class - " + class_label)
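Extending the basic confusion matrix, this draws a grid of subplots, one per class, with the class name as the title. ["N", "Y"] are the class names used for each 2x2 matrix (index 0 is the negative case, index 1 the positive case) and can be changed.
fig, ax = plt.subplots(4, 4, figsize=(12, 7))
for axes, cfs_matrix, label in zip(ax.flatten(), vis_arr, labels):
    # Each block is ordered [[TN, FP], [FN, TP]], so "N" must come first.
    print_confusion_matrix(cfs_matrix, axes, label, ["N", "Y"])
fig.tight_layout()
plt.show()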
Output: [image omitted]
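The array above was pasted in by hand, but the same input for the heatmap grid can come straight from the model's predictions. A minimal sketch, assuming the true and predicted labels are binary indicator arrays of shape (n_samples, n_labels); the tiny arrays below are made-up placeholders standing in for y_test and dtree_predictions:

```python
import numpy as np
from sklearn.metrics import multilabel_confusion_matrix

# Made-up indicator arrays: 4 samples, 3 labels.
y_true = np.array([[1, 0, 1],
                   [0, 1, 0],
                   [1, 1, 0],
                   [0, 0, 1]])
y_pred = np.array([[1, 0, 0],
                   [0, 1, 0],
                   [1, 0, 0],
                   [0, 0, 1]])

# One 2x2 block per label, each laid out as [[TN, FP], [FN, TP]].
vis_arr = multilabel_confusion_matrix(y_true, y_pred)

# One title per label, derived from the array instead of a fixed count.
labels = [f"c{i}" for i in range(vis_arr.shape[0])]

print(vis_arr.shape)
print(labels)
```

With real data, vis_arr and labels built this way can be passed to the plotting loop unchanged; only the subplot grid has to be large enough for the number of labels.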
Answer 1 (score: 2)
You can use ConfusionMatrixDisplay from sklearn.metrics.
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_multilabel_classification
from sklearn.tree import DecisionTreeClassifier
# Synthetic multilabel data with 15 labels
X, y = make_multilabel_classification(n_samples=1000,
                                      n_classes=15, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=42)
tree = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)
y_pred = tree.predict(X_test)
f, axes = plt.subplots(3, 5, figsize=(25, 15))
axes = axes.ravel()
for i in range(15):
    # One binary confusion matrix per label column
    disp = ConfusionMatrixDisplay(confusion_matrix(y_test[:, i],
                                                   y_pred[:, i]),
                                  display_labels=[0, i])
    disp.plot(ax=axes[i], values_format='.4g')
    disp.ax_.set_title(f'class {i}')
    if i < 10:
        disp.ax_.set_xlabel('')   # keep x labels on the bottom row only
    if i % 5 != 0:
        disp.ax_.set_ylabel('')   # keep y labels on the left column only
    disp.im_.colorbar.remove()    # drop the per-subplot colorbar
plt.subplots_adjust(wspace=0.10, hspace=0.1)
f.colorbar(disp.im_, ax=axes)     # one shared colorbar, taken from the last subplot
plt.show()
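The loop above hard-codes a 3x5 grid and removes each colorbar after plotting. A sketch of a more general variant, assuming scikit-learn 0.24 or later (where ConfusionMatrixDisplay.plot accepts colorbar=False); plot_label_grid and its n_cols parameter are hypothetical names introduced here, and the random data is made up:

```python
import math

import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to use plt.show() interactively
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix


def plot_label_grid(y_true, y_pred, n_cols=5):
    """Draw one 2x2 confusion matrix per label column on a grid of subplots."""
    n_labels = y_true.shape[1]
    n_rows = math.ceil(n_labels / n_cols)
    fig, axes = plt.subplots(n_rows, n_cols,
                             figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
    axes = axes.ravel()
    for i in range(n_labels):
        # labels=[0, 1] keeps the matrix 2x2 even if a column holds a single value.
        cm = confusion_matrix(y_true[:, i], y_pred[:, i], labels=[0, 1])
        disp = ConfusionMatrixDisplay(cm, display_labels=[0, 1])
        disp.plot(ax=axes[i], values_format="d", colorbar=False)
        axes[i].set_title(f"class {i}")
    for ax in axes[n_labels:]:
        ax.set_visible(False)  # hide unused cells of the grid
    fig.tight_layout()
    return fig, axes


# Made-up indicator data: 50 samples, 7 labels.
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=(50, 7))
y_pred = rng.integers(0, 2, size=(50, 7))
fig, axes = plot_label_grid(y_true, y_pred, n_cols=4)
```

Computing the number of rows from the label count means the same function works for 15 labels, 16 labels, or any other number, and leftover cells are hidden instead of showing empty axes.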