我正在尝试使用scikitlearn 0.22中提供的plot_confusion_matrix
函数。但是,我遇到一个问题,因为我的其中一个框看不到框中的文本值,因为所有文本颜色似乎都设置为与该框相同的值。无论我选择哪个cmap
,都是如此。顺便说一句,该值也是最低的。 example they provide中不会发生这种情况。如何更改方框文本的颜色,以便可以清楚地看到所有值?除非必须,否则我不想像其他许多解决方案中所建议的那样使用seaborn
。
下面是一个可复制的示例。
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import plot_confusion_matrix
np.random.seed(3851)
# import some data to play with
bc = datasets.load_breast_cancer()
X = bc.data
y = bc.target
class_names = bc.target_names
# Split the data into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
np.random.shuffle(y_test)
# Run classifier, using a model that is too regularized (C too low) to see
# the impact on the results
classifier = svm.SVC(kernel='linear', C=0.0001).fit(X_train, y_train)
np.set_printoptions(precision=2)
# Plot non-normalized confusion matrix
titles_options = [("Confusion matrix, without normalization", None),
("Normalized confusion matrix", 'true')]
for title, normalize in titles_options:
disp = plot_confusion_matrix(classifier, X_test, y_test,
cmap=plt.cm.Blues,
normalize=normalize)
disp.ax_.set_title(title)
print(title)
print(disp.confusion_matrix)
plt.show()
答案 0 :(得分:1)
我相信您遇到了一个错误。我已提交the issue on github,并提供了可能的解决方法。
基本上,根据颜色值是高于还是低于阈值(该阈值应位于cmap范围的中间)来选择文本的颜色。但是我相信阈值的计算方式存在问题,因此归一化示例中的所有值最终都将低于阈值并以较浅的颜色绘制。
如果要临时修复,可以在第96行的scikit-learn .../site-packages/sklearn/metrics/_plot/confusion_matrix.py
安装中修改一个文件,
thresh = cm.min()+(cm.max() - cm.min()) / 2.
而不是thresh = (cm.max() - cm.min()) / 2.
===================================================================
--- metrics/_plot/confusion_matrix.py (date 1576701552905)
+++ metrics/_plot/confusion_matrix.py (date 1576701552905)
@@ -93,7 +93,7 @@
values_format = '.2g'
# print text with appropriate color depending on background
- thresh = (cm.max() - cm.min()) / 2.
+ thresh = cm.min()+(cm.max() - cm.min()) / 2.
for i, j in product(range(n_classes), range(n_classes)):
color = cmap_max if cm[i, j] < thresh else cmap_min
self.text_[i, j] = ax.text(j, i,