每个人。使用matplotlib绘制热图时出现问题。
简而言之,输出热图数字已“裁剪” 。 我想知道为什么会出现此问题以及如何解决该问题 。
问题描述如下。
1。环境
该程序是在以下环境中创建的。
matplotlib==3.1.1
numpy==1.16.4
pandas==0.24.2
所有库都通过pip安装在conda虚拟环境中。
2。参考代码
我的代码是根据以下官方文档创建的:Creating annotated heatmaps
3。我的代码
我的代码如下:
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable
def heatmap(data, ax, fmt='{x:d}'):
"""Plot a heatmap with matplotlib.
Args:
data (pandas.DataFrame): the matrix to be labeled.
ax (matplotlib.axes.Axes): axes where the confusion matrix is drawn.
fmt (string, optional): the format of the annotations inside the
heatmap. This should either use the string format method,
e.g. "{x:.2f}" or "{x:d}".
Returns:
ax (matplotlib.axes.Axes): axes where the confusion matrix is drawn.
"""
if not isinstance(data, pd.DataFrame):
raise TypeError('the input data should be a pandas DataFrame object')
if data.shape[0] != data.shape[1]:
raise ValueError('the data should be a square-matrix')
im = ax.imshow(data.values, cmap='magma', interpolation='nearest')
divider = make_axes_locatable(ax)
cax = divider.append_axes('right', size='5%', pad=0.15)
cbar = plt.colorbar(im, cax=cax,
ticks=mpl.ticker.MaxNLocator(nbins=6))
cbar.ax.tick_params(labelsize=12)
cbar.outline.set_visible(False)
ax.set_xticks(list(np.arange(data.shape[1])))
ax.set_yticks(list(np.arange(data.shape[0])))
column_labels = list(data.columns.values.astype(str))
ax.set_xticklabels(column_labels, fontsize=12, rotation='vertical')
ax.set_yticklabels(column_labels, fontsize=12)
for spine in ax.spines.values():
spine.set_visible(False)
valfmt = mpl.ticker.StrMethodFormatter(fmt)
# Change the text's color depending on the background.
text_colors = ['white', '0.15']
threshold = 0.6
for i in range(data.shape[0]):
for j in range(data.shape[1]):
use_dark = im.norm(data.iloc[i, j]) > threshold
im.axes.text(
j, i, valfmt(data.iloc[i, j], None),
ha="center", va="center", fontsize=12,
color=text_colors[int(use_dark)])
return ax
4。输入数据
输入数据是一个非常简单的pandas DataFrame。
confusion_df = pd.DataFrame([[89, 4], [5, 80]], index=['dog', 'cat'], columns=['dog', 'cat'])
5。输出
我创建了一个如下的混乱图。
fig, ax = plt.subplots(figsize=(4, 4))
heatmap(confusion_df, ax=ax, fmt='{x:d}')
ax.set_xlabel('Prediction', fontsize=16)
ax.set_ylabel('Ground truth', fontsize=16)
fig.tight_layout()
fig.savefig('confusion_map.png')
plt.show()
6。结果
7。我想要得到的结果。
正确的输出应为如下所示的正方形。您还可以检查参考文献中2中给出的结果。
8。我尝试过的方法
我尝试了以下方法,但是它们不起作用。
(a)将我的代码与参考文档中给出的代码进行了比较(请参阅2)
(b)设置较大的图形尺寸,例如3中的fig, ax = plt.subplots(figsize=(6, 6))
(c)在5中删除fig.tight_layout()
非常感谢!