我找到了an excellent tutorial on drawing a heatmap for a confusion matrix,但是我想在侧面增加一些佣金和遗漏的错误。
我将尝试解释使用此图像:
这意味着:
我需要在每个包含数字0、6和9的框旁边插入一个数字,该框位于图像右边缘的右侧和图例的左侧
我需要在每个框的上方分别插入一个数字,该框包含13、0和0,位于图像顶部边缘上方,标题下方。
(所以总共6个数字)
这甚至可能吗?我对Python的绘图功能一无所知,因为我是该语言的新手。从我站着的位置看来,这似乎是一项非常艰巨的任务。
答案 0 :(得分:1)
您可以使用刻度线进行此操作。
让我用下面的简单图解介绍这种方法:
from matplotlib import pyplot as plt
ax = plt.axes()
ax.set_xlim(0, 3)
ax.set_ylim(0, 3)
for i in range(3):
for j in range(3):
ax.fill_between((i, i+1), j, j+1)
ax.fill_between((i, i+1), j, j+1)
ax.fill_between((i, i+1), j, j+1)
plt.show()
我既不会关注刻度线样式,也不会关注颜色,但是知道您可以非常轻松地进行更改。
您可以创建一个Axes
对象,该对象将与ax
共享ax.twiny()
的Y轴。然后,您可以在这个新的Axes
上添加X刻度,它将显示在图的顶部:
from matplotlib import pyplot as plt
ax = plt.axes()
ax.set_xlim(0, 3)
ax.set_ylim(0, 3)
for i in range(3):
for j in range(3):
ax.fill_between((i, i+1), j, j+1)
ax.fill_between((i, i+1), j, j+1)
ax.fill_between((i, i+1), j, j+1)
ax2 = ax.twiny()
ax2.set_xlim(ax.get_xlim())
ax2.set_xticks([0.5, 1.5, 2.5])
ax2.set_xticklabels([13, 0, 0])
plt.show()
为了显示X轴的刻度,您必须创建一个Axes
对象,该对象与ax
共享ax.twiny()
的Y轴。这看起来似乎违反直觉,但是如果您使用ax.twinx()
,则修改ax2
的X刻度也会修改ax
的X刻度,因为它们实际上是相同的。
然后,您要设置ax2
的X窗口,使其具有三个正方形。
之后,您可以设置刻度线:每一个正方形在水平中心,所以在[0.5, 1.5, 2.5]
处。
最后,您可以设置刻度标签以显示所需的值。
然后,您只需对Y勾进行相同操作即可:
from matplotlib import pyplot as plt
ax = plt.axes()
ax.set_xlim(0, 3)
ax.set_ylim(0, 3)
for i in range(3):
for j in range(3):
ax.fill_between((i, i+1), j, j+1)
ax.fill_between((i, i+1), j, j+1)
ax.fill_between((i, i+1), j, j+1)
ax2 = ax.twiny()
ax2.set_xlim(ax.get_xlim())
ax2.set_xticks([0.5, 1.5, 2.5])
ax2.set_xticklabels([13, 0, 0])
ax3 = ax.twinx()
ax3.set_ylim(ax.get_ylim())
ax3.set_yticks([0.5, 1.5, 2.5])
ax3.set_yticklabels([0, 6, 9])
plt.show()
答案 1 :(得分:1)
使用以下修改的功能。想法如下:
y=1.1
def plot_confusion_matrix(y_true, y_pred, classes, normalize=False,
title=None, cmap=plt.cm.Blues):
if not title:
if normalize:
title = 'Normalized confusion matrix'
else:
title = 'Confusion matrix, without normalization'
cm = confusion_matrix(y_true, y_pred)
classes = classes[unique_labels(y_true, y_pred)]
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print("Normalized confusion matrix")
else:
print('Confusion matrix, without normalization')
fig, ax = plt.subplots(figsize=(6.5,6))
im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
ax.figure.colorbar(im, ax=ax)
ax.set(xticks=np.arange(cm.shape[1]),
yticks=np.arange(cm.shape[0]),
xticklabels=classes, yticklabels=classes,
ylabel='True label',
xlabel='Predicted label')
ax.set_title(title, y=1.1)
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
rotation_mode="anchor")
# Adding data to the right
ax2 = ax.twinx()
ax2.set_ylim(ax.get_ylim())
ax2.set_yticks(np.arange(cm.shape[0]))
ax2.set_yticklabels(cm[:, -1])
ax2.tick_params(axis="y", right=False)
# Adding data to the top
ax3 = ax.twiny()
ax3.set_xlim(ax.get_xlim())
ax3.set_xticks(np.arange(cm.shape[0]))
ax3.set_xticklabels(cm[:, 0])
ax3.tick_params(axis="x", top=False)
ax.set_aspect('auto')
fmt = '.2f' if normalize else 'd'
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
ax.text(j, i, format(cm[i, j], fmt),
ha="center", va="center",
color="white" if cm[i, j] > thresh else "black")
fig.tight_layout()
return ax
答案 2 :(得分:0)
一种相当手工的方法将包括以下各项的组合,直到结果令人满意为止:
twinax = ax.twinx().twiny()
twinax.set(xlim=ax.get_xlim(), ylim=ax.get_ylim())
将其范围与原始轴的范围匹配,然后... twinax.set(xticks=ax.get_xticks(), yticks=ax.get_yticks, xticklabels=('0','1','2'), yticklabels = ('0','1','2'))
来设置新轴上的标签,如您的示例中所做的那样(如果您愿意,可以将这两个调用合并在一起)。