I wrote the following function to plot a confusion matrix in Python / matplotlib:
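# Imports inferred from the names used below (an assumption about the setup)
from itertools import product

from matplotlib import cm
from matplotlib.pyplot import (colorbar, figure, imshow, show, text,
                               tight_layout, xlabel, xticks, ylabel, yticks)
from numpy import arange
from sklearn.metrics import confusion_matrix


def print_conf_mat(y_test, y_pred):
    conf_mat = confusion_matrix(y_test, y_pred)
    classes = ['Normal', 'Type 1', 'Type 2', 'Type 3', 'Type 4', 'Type 5', 'Type 6']
    imshow(conf_mat, interpolation='nearest', cmap=cm.Blues)
    print('Confusion matrix:')
    colorbar()
    # One tick per class on both axes
    tick_marks = arange(len(classes))
    xticks(tick_marks, classes, rotation=45)
    yticks(tick_marks, classes)
    fmt = 'd'
    thresh = conf_mat.max() / 2.
    # Write each count in its cell, in white on dark cells and black on light ones
    for i, j in product(range(conf_mat.shape[0]), range(conf_mat.shape[1])):
        text(j, i, format(conf_mat[i, j], fmt),
             horizontalalignment="center",
             color="white" if conf_mat[i, j] > thresh else "black")
    xlabel('Predicted label')
    ylabel('True label')
    tight_layout()
    figure()
    show()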
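I get a matrix in the following format:
[Image: Confusion matrix plotted with matplotlib]
I would like to get rid of the annoying line that is displayed after the plot: <Figure size 432x288 with 0 Axes>
Any clues on how I can achieve this?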
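In case it helps narrow this down, here is a minimal sketch of what I suspect is happening. This is an assumption on my part, not a confirmed diagnosis: the `figure()` call just before `show()` seems to open a second, empty figure, and a notebook front end would then report that figure as having 0 axes. The helper `axes_per_figure` and its `open_extra_figure` flag are hypothetical names used only for this illustration, and the `Agg` backend is chosen only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def axes_per_figure(mat, open_extra_figure):
    """Plot `mat` as in the function above and return the number of axes
    held by each open figure (hypothetical helper, for illustration only)."""
    plt.close("all")                      # start from a clean state
    plt.imshow(mat, interpolation="nearest", cmap="Blues")
    plt.colorbar()
    plt.tight_layout()
    if open_extra_figure:
        plt.figure()                      # opens a new figure with nothing drawn on it
    return [len(plt.figure(num).axes) for num in plt.get_fignums()]


mat = np.arange(9).reshape(3, 3)
with_extra = axes_per_figure(mat, open_extra_figure=True)
without_extra = axes_per_figure(mat, open_extra_figure=False)
print("with figure():   ", with_extra)
print("without figure():", without_extra)
```

If this guess is right, the last entry of the first list is the empty figure, and dropping the `figure()` call (or moving it to the top of the function, before `imshow`) should leave only the figure that holds the matrix.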