我正在使用以下代码绘制混淆矩阵:
labels = test_y.unique()
predictions = chosen_clf.predict(X=test_x)
conf_matrix = confusion_matrix(y_true=test_y, y_pred=predictions, labels=labels)
conf_matrix = pd.DataFrame(conf_matrix, index=labels, columns=labels)
plt.figure()
sn.heatmap(conf_matrix, annot=True)
plt.savefig(r'confusion_matrix.png')
但是,剧情正在出现一些问题: -标签被切掉 -网格不够宽,导致值无法读取,例如2e + 02处于(1,1)位置。
我该如何解决?
答案 0 :(得分:1)
尝试添加:
plt.figure(figsize=(20,20))
在行之前:
sn.heatmap()