使用matplotlib绘制热图

时间:2019-07-17 17:11:31

标签: python matplotlib heatmap

每个人。使用matplotlib绘制热图时出现问题。

简而言之,输出热图数字已“裁剪” 。   我想知道为什么会出现此问题以及如何解决该问题

问题描述如下。

1。环境

该程序是在以下环境中创建的。

matplotlib==3.1.1
numpy==1.16.4
pandas==0.24.2

所有库都通过pip安装在conda虚拟环境中。

2。参考代码

我的代码是根据以下官方文档创建的:Creating annotated heatmaps

3。我的代码

我的代码如下:


    import matplotlib as mpl
    import matplotlib.pyplot as plt
    import pandas as pd
    import numpy as np
    from mpl_toolkits.axes_grid1 import make_axes_locatable


    def heatmap(data, ax, fmt='{x:d}'):
    """Plot a heatmap with matplotlib.

    Args:
        data (pandas.DataFrame): the matrix to be labeled.
        ax (matplotlib.axes.Axes): axes where the confusion matrix is drawn.
        fmt (string, optional): the format of the annotations inside the
            heatmap. This should either use the string format method,
            e.g. "{x:.2f}" or "{x:d}".

    Returns:
        ax (matplotlib.axes.Axes): axes where the confusion matrix is drawn.
    """

    if not isinstance(data, pd.DataFrame):
        raise TypeError('the input data should be a pandas DataFrame object')

    if data.shape[0] != data.shape[1]:
        raise ValueError('the data should be a square-matrix')

    im = ax.imshow(data.values, cmap='magma', interpolation='nearest')

    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.15)
    cbar = plt.colorbar(im, cax=cax,
                        ticks=mpl.ticker.MaxNLocator(nbins=6))
    cbar.ax.tick_params(labelsize=12)
    cbar.outline.set_visible(False)

    ax.set_xticks(list(np.arange(data.shape[1])))
    ax.set_yticks(list(np.arange(data.shape[0])))

    column_labels = list(data.columns.values.astype(str))
    ax.set_xticklabels(column_labels, fontsize=12, rotation='vertical')
    ax.set_yticklabels(column_labels, fontsize=12)

    for spine in ax.spines.values():
        spine.set_visible(False)

    valfmt = mpl.ticker.StrMethodFormatter(fmt)

    # Change the text's color depending on the background.
    text_colors = ['white', '0.15']
    threshold = 0.6

    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            use_dark = im.norm(data.iloc[i, j]) > threshold
            im.axes.text(
                j, i, valfmt(data.iloc[i, j], None),
                ha="center", va="center", fontsize=12,
                color=text_colors[int(use_dark)])

    return ax

4。输入数据

输入数据是一个非常简单的pandas DataFrame。


    confusion_df = pd.DataFrame([[89, 4], [5, 80]], index=['dog', 'cat'], columns=['dog', 'cat'])

5。输出

我创建了一个如下的混乱图。


    fig, ax = plt.subplots(figsize=(4, 4))

    heatmap(confusion_df, ax=ax, fmt='{x:d}')

    ax.set_xlabel('Prediction', fontsize=16)
    ax.set_ylabel('Ground truth', fontsize=16)
    fig.tight_layout()

    fig.savefig('confusion_map.png')
    plt.show()

6。结果

结果是非常奇怪的,就像我说的那样,这个数字被“裁剪”了。 The 'cropped' heatmap

7。我想要得到的结果。

正确的输出应为如下所示的正方形。您还可以检查参考文献中2中给出的结果。

The figure I want to get

8。我尝试过的方法

我尝试了以下方法,但是它们不起作用。

(a)将我的代码与参考文档中给出的代码进行了比较(请参阅2)

(b)设置较大的图形尺寸,例如3中的fig, ax = plt.subplots(figsize=(6, 6))

(c)在5中删除fig.tight_layout()

非常感谢!

0 个答案:

没有答案