混淆矩阵中不同大小的单元的问题

时间:2019-10-29 12:58:09

标签: matplotlib

使用Python 3.7上的最新matplotlib版本,我尝试绘制一个混淆矩阵(并将其保存为png)。虽然所得图原则上很好,但单元的大小不同,请参见此处: confusion matrix

如屏幕截图所示,实际上只有中间单元格的大小正确,所有其他单元格(在这种情况下,即所有边界单元格)的大小似乎都只有中间单元格的一半甚至四分之一。 / p>

我正在运行的源代码很简单:

import os

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix


def create_save_plotted_confusion_matrix(conf_matrix, expected_labels, basepath):
    ax, title = plot_confusion_matrix(conf_matrix, expected_labels, normalize=False)
    filepath = os.path.join(basepath, '.png')
    plt.savefig(filepath, bbox_inches='tight')


def plot_confusion_matrix(cm, classes, normalize=False, title=None, cmap=plt.cm.Blues):
    if not title:
        if normalize:
            title = 'Normalized confusion matrix'
        else:
            title = 'Confusion matrix, without normalization'

    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    else:
        pass

    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    # We want to show all ticks...
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           # ... and label them with the respective list entries
           xticklabels=classes, yticklabels=classes,
           title=title,
           ylabel='True label',
           xlabel='Predicted label')

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Loop over data dimensions and create text annotations.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()

    return ax, title


if __name__ == '__main__':
    y_true = ["cat", "ant", "cat", "cat", "ant", "bird"]
    y_pred = ["ant", "ant", "cat", "cat", "ant", "cat"]
    confmat = confusion_matrix(y_true, y_pred, labels=["ant", "bird", "cat"])

    create_save_plotted_confusion_matrix(confmat, ["ant", "bird", "cat"], '.')

0 个答案:

没有答案