如何根据GradientBoost结果绘制热图?

时间:2019-03-27 19:16:50

标签: python pandas machine-learning seaborn

我想根据y_predict和y_train的结果打印混淆矩阵的热图。

我有些困惑,我已经查看了热图的熊猫文档,但是仍然不知道如何将其应用于结果。我使用的数据集是关于收入的,具有分类和数值数据。我已经应用了GB分类器,并且得到了结果。 唯一剩下的就是热图。

print(confusion_matrix(y_train,y_pred_train))
print(y_train)

这就是结果

Confusion Matrix:


[[14151   710]
 [ 1844  2831]]
Name: income, Length: 19536, dtype: int64

这是尝试制作热图

import seaborn as sns
class_names = y_train, y_pred_train

def print_confusion_matrix(confusion_matrix, class_names, figsize = (10,7), fontsize=14):
    df_cm = pd.DataFrame(
        confusion_matrix, index=class_names, columns=class_names, 
    )
    fig = plt.figure(figsize=figsize)
    try:
        heatmap = sns.heatmap(df_cm, annot=True, fmt="d")
    except ValueError:
        raise ValueError("Confusion matrix values must be integers.")
    heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=fontsize)
    heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=fontsize)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    return fig

返回了

NameError                                 Traceback (most recent call last)
<ipython-input-36-3bd0e9ee90a4> in <module>()
     18     plt.xlabel('Predicted label')
     19     return fig
---> 20 print(fig)

NameError: name 'fig' is not defined

当我在结果中制作混淆矩阵的热图时,我会缺少什么?

1 个答案:

答案 0 :(得分:0)

您可以使用混淆矩阵和类名称列表作为参数来调用print_confusion_matrix函数:

def print_confusion_matrix(confusion_matrix, class_names, figsize = (10,7), fontsize=14):
    df_cm = pd.DataFrame(
        confusion_matrix, index=class_names, columns=class_names, 
    )
    fig = plt.figure(figsize=figsize)
    try:
        heatmap = sns.heatmap(df_cm, annot=True, fmt="d")
    except ValueError:
        raise ValueError("Confusion matrix values must be integers.")
    heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=fontsize)
    heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=fontsize)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    return fig

confusion_matrix = np.array([[14151, 710], [1844, 2831]])
fig = print_confusion_matrix(confusion_matrix, ['0', '1'])

输出:

enter image description here