我想根据y_predict和y_train的结果打印混淆矩阵的热图。
我有些困惑,我已经查看了热图的熊猫文档,但是仍然不知道如何将其应用于结果。我使用的数据集是关于收入的,具有分类和数值数据。我已经应用了GB分类器,并且得到了结果。 唯一剩下的就是热图。
print(confusion_matrix(y_train,y_pred_train))
print(y_train)
这就是结果
Confusion Matrix:
[[14151 710]
[ 1844 2831]]
Name: income, Length: 19536, dtype: int64
这是尝试制作热图
import seaborn as sns
class_names = y_train, y_pred_train
def print_confusion_matrix(confusion_matrix, class_names, figsize = (10,7), fontsize=14):
df_cm = pd.DataFrame(
confusion_matrix, index=class_names, columns=class_names,
)
fig = plt.figure(figsize=figsize)
try:
heatmap = sns.heatmap(df_cm, annot=True, fmt="d")
except ValueError:
raise ValueError("Confusion matrix values must be integers.")
heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=fontsize)
heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=fontsize)
plt.ylabel('True label')
plt.xlabel('Predicted label')
return fig
返回了
NameError Traceback (most recent call last)
<ipython-input-36-3bd0e9ee90a4> in <module>()
18 plt.xlabel('Predicted label')
19 return fig
---> 20 print(fig)
NameError: name 'fig' is not defined
当我在结果中制作混淆矩阵的热图时,我会缺少什么?
答案 0 :(得分:0)
您可以使用混淆矩阵和类名称列表作为参数来调用print_confusion_matrix
函数:
def print_confusion_matrix(confusion_matrix, class_names, figsize = (10,7), fontsize=14):
df_cm = pd.DataFrame(
confusion_matrix, index=class_names, columns=class_names,
)
fig = plt.figure(figsize=figsize)
try:
heatmap = sns.heatmap(df_cm, annot=True, fmt="d")
except ValueError:
raise ValueError("Confusion matrix values must be integers.")
heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=fontsize)
heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=fontsize)
plt.ylabel('True label')
plt.xlabel('Predicted label')
return fig
confusion_matrix = np.array([[14151, 710], [1844, 2831]])
fig = print_confusion_matrix(confusion_matrix, ['0', '1'])
输出: