我正在使用以下函数来生成混淆矩阵:
def plot_confusion_matrix(cm, classes, normalize=False, cmap=cm.Blues, png_output=None, show=True):
    """Plot a confusion matrix, optionally saving it as a PNG and/or showing it.

    Parameters
    ----------
    cm : array-like, shape (n_classes, n_classes)
        Confusion matrix; rows are drawn as true labels and columns as
        predicted labels.
    classes : sequence of str
        Tick labels, one per row/column of ``cm``.
    normalize : bool
        If True, each row is divided by its sum so it shows per-class
        fractions. Rows whose sum is 0 stay at 0.0 instead of becoming NaN.
    cmap : matplotlib colormap
        Colormap passed to ``imshow``.
    png_output : str or None
        Directory in which ``confusion_matrix.png`` is written (created if
        missing). Nothing is saved when None.
    show : bool
        If True, ``plt.show()`` is called before the figure is closed.
    """
    cm = np.asarray(cm)
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # A class with no true samples has an all-zero row; dividing by 1
        # instead of 0 keeps it at 0.0 rather than NaN (NaN would also break
        # cm.max() and hence the text-colour threshold below).
        cm = cm.astype('float') / np.where(row_sums == 0, 1, row_sums)
        title = 'Normalized confusion matrix'
    else:
        title = 'Confusion matrix'

    f = plt.figure()
    try:
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title(title)
        plt.colorbar()
        tick_marks = np.arange(len(classes))
        # Right-align rotated labels so each one ends under its own tick.
        plt.xticks(tick_marks, classes, rotation=45, ha='right')
        plt.yticks(tick_marks, classes)

        fmt = '.2f' if normalize else 'd'
        # Cells above half the maximum value get white text for contrast.
        thresh = cm.max() / 2.
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            plt.text(j, i, format(cm[i, j], fmt),
                     horizontalalignment="center",
                     verticalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")

        plt.ylabel('True label')
        plt.xlabel('Predicted label')
        # tight_layout must run after the axis labels exist; otherwise it
        # cannot reserve room for them and they get clipped.
        plt.tight_layout()

        if png_output is not None:
            os.makedirs(png_output, exist_ok=True)
            f.savefig(os.path.join(png_output, 'confusion_matrix.png'), bbox_inches='tight')
        if show:
            plt.show()
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(f)
当类别较少时，我会得到一个整齐的图表，像这样：
但是当类别很多时，我会得到：
我尝试了与“Python boxplot matplotlib automatic figure size based on the number of categories”（根据类别数量自动设置 matplotlib 箱线图尺寸）中的解决方案相同的方法，但没有奏效。
如何像上面的箱线图解决方案那样，让混淆矩阵根据类别数量自动调整其尺寸？
更新1
加入了刻度位置设置，以及随类别数量动态变化的图形宽度：
def plot_confusion_matrix(y_true, y_pred, classes, normalize=False, cmap=cm.Blues, png_output=None, show=True):
    """Compute and plot a confusion matrix whose figure size grows with the class count.

    Parameters
    ----------
    y_true, y_pred : array-like
        True and predicted labels, passed straight to ``confusion_matrix``.
    classes : sequence of str
        Tick labels, one per row/column of the resulting matrix.
    normalize : bool
        If True, each row is divided by its sum so it shows per-class
        fractions. Rows whose sum is 0 stay at 0.0 instead of becoming NaN.
    cmap : matplotlib colormap
        Colormap passed to ``imshow``.
    png_output : str or None
        Directory in which ``confusion_matrix.png`` is written (created if
        missing). Nothing is saved when None.
    show : bool
        If True, ``plt.show()`` is called before the figure is closed.
    """
    cm = confusion_matrix(y_true, y_pred)
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # An all-zero row would give 0/0 = NaN; divide by 1 there instead.
        cm = cm.astype('float') / np.where(row_sums == 0, 1, row_sums)
        title = 'Normalized confusion matrix'
    else:
        title = 'Confusion matrix'

    # Square figure: fixed margins plus a fixed size per category, so cells
    # keep the same physical size no matter how many classes there are.
    leftmargin = 0.5  # inches
    rightmargin = 0.5  # inches
    categorysize = 0.5  # inches
    figwidth = leftmargin + rightmargin + (len(classes) * categorysize)

    f = plt.figure(figsize=(figwidth, figwidth))
    try:
        ax = f.add_subplot(111)
        ax.set_aspect(1)
        # Convert the inch margins into figure fractions.
        f.subplots_adjust(left=leftmargin / figwidth, right=1 - rightmargin / figwidth,
                          top=0.94, bottom=0.1)
        res = ax.imshow(cm, interpolation='nearest', cmap=cmap)
        ax.set_title(title)
        f.colorbar(res, ax=ax)

        ax.set_xticks(range(len(classes)))
        ax.set_yticks(range(len(classes)))
        ax.set_xticklabels(classes, rotation=45, ha='right')
        ax.set_yticklabels(classes)

        fmt = '.2f' if normalize else 'd'
        # Cells above half the maximum value get white text for contrast.
        thresh = cm.max() / 2.
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            ax.text(j, i, format(cm[i, j], fmt),
                    horizontalalignment="center",
                    verticalalignment="center",
                    color="white" if cm[i, j] > thresh else "black")

        ax.set_ylabel('True label')
        ax.set_xlabel('Predicted label')

        if png_output is not None:
            os.makedirs(png_output, exist_ok=True)
            f.savefig(os.path.join(png_output, 'confusion_matrix.png'), bbox_inches='tight')
        if show:
            plt.show()
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(f)
此致，克莱森·里奥斯（Kleyson Rios）。
答案 0（得分：0）
在 @ImportanceOfBeingErnest 的帮助下，现在已经能得到正确的图表了。
最终代码如下：
def plot_confusion_matrix(cm, classes, normalize=False, cmap=cm.Blues, png_output=None, show=True):
    """Plot a confusion matrix whose figure size grows with the number of classes.

    Parameters
    ----------
    cm : array-like, shape (n_classes, n_classes)
        Confusion matrix; rows are drawn as true labels and columns as
        predicted labels.
    classes : sequence of str
        Tick labels, one per row/column of ``cm``.
    normalize : bool
        If True, each row is divided by its sum so it shows per-class
        fractions. Rows whose sum is 0 stay at 0.0 instead of becoming NaN.
    cmap : matplotlib colormap
        Colormap passed to ``imshow``.
    png_output : str or None
        Directory in which ``confusion_matrix.png`` is written (created if
        missing). Nothing is saved when None.
    show : bool
        If True, ``plt.show()`` is called before the figure is closed.
    """
    cm = np.asarray(cm)
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # An all-zero row would give 0/0 = NaN; divide by 1 there instead.
        cm = cm.astype('float') / np.where(row_sums == 0, 1, row_sums)
        title = 'Normalized confusion matrix'
    else:
        title = 'Confusion matrix'

    # Square figure: fixed margins plus a fixed size per category, so cells
    # keep the same physical size no matter how many classes there are.
    leftmargin = 0.5  # inches
    rightmargin = 0.5  # inches
    categorysize = 0.5  # inches
    figwidth = leftmargin + rightmargin + (len(classes) * categorysize)

    f = plt.figure(figsize=(figwidth, figwidth))
    try:
        ax = f.add_subplot(111)
        ax.set_aspect(1)
        # Convert the inch margins into figure fractions.
        f.subplots_adjust(left=leftmargin / figwidth, right=1 - rightmargin / figwidth,
                          top=0.94, bottom=0.1)
        res = ax.imshow(cm, interpolation='nearest', cmap=cmap)
        ax.set_title(title)
        f.colorbar(res, ax=ax)

        ax.set_xticks(range(len(classes)))
        ax.set_yticks(range(len(classes)))
        ax.set_xticklabels(classes, rotation=45, ha='right')
        ax.set_yticklabels(classes)

        fmt = '.2f' if normalize else 'd'
        # Cells above half the maximum value get white text for contrast.
        thresh = cm.max() / 2.
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            ax.text(j, i, format(cm[i, j], fmt),
                    horizontalalignment="center",
                    verticalalignment="center",
                    color="white" if cm[i, j] > thresh else "black")

        ax.set_ylabel('True label')
        ax.set_xlabel('Predicted label')

        if png_output is not None:
            os.makedirs(png_output, exist_ok=True)
            f.savefig(os.path.join(png_output, 'confusion_matrix.png'), bbox_inches='tight')
        if show:
            plt.show()
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(f)