Confusion matrix for categorical data... the category labels are sorted alphabetically

Time: 2019-05-10 06:34:45

Tags: python confusion-matrix

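I am trying to create a confusion matrix in Python. The matrix itself is correct, but the category labels on the axes appear in alphabetical order. I do not want them sorted alphabetically.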

This is the code I tried.

import itertools
import numpy as np
import matplotlib.pyplot as plt

from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

# import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target
#class_names =iris.target_names
class_names ={'No_Dementia','Very_Mild','Mild','Moderate'}

# Split the data into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Run classifier, using a model that is too regularized (C too low) to see
# the impact on the results
classifier = svm.SVC(kernel='linear', C=0.01)
y_pred = classifier.fit(X_train, y_train).predict(X_test)
print(y_pred)
print(class_names)

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    else:
        print('Confusion matrix, without normalization')

    cm = [[1,0,0,0], [0.026,0.921,0.052,0], [0,0.166,0.792,0.042], [0,0,0,1]]

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    #tick_marks = np.arange(len(classes))
    tick_marks = np.arange(4)
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes, rotation=45)

    plt.tight_layout()
    plt.ylabel('Actual Classes')
    plt.xlabel('Predicted Classes')

# Compute confusion matrix
cnf_matrix = confusion_matrix(y_test, y_pred)
np.set_printoptions(precision=2)

# Plot non-normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names,
                  )

# Plot normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
                  title='Normalized confusion matrix')

plt.show()

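The labels I pass for the x-axis are not displayed the way I gave them. I want my xticks labels to appear in the order I defined in class_names = {'No_Dementia','Very_Mild','Mild','Moderate'}, but in the output they come out in alphabetical order.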
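
Judging only from the code above, one likely cause is that class_names is written with curly braces, which makes it a Python set. Sets do not preserve the order in which their elements were written, so the tick labels can come out in an order other than the intended one. The following is a minimal sketch, not a confirmed fix, that stores the names in a list and reuses the hard-coded matrix from the question. The Agg backend and the shown_labels variable are assumptions added here only so the sketch runs without a display and the resulting order can be inspected.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend, so no display is needed
import matplotlib.pyplot as plt
import numpy as np

# A list keeps the order in which the names are written; a set does not.
class_names = ['No_Dementia', 'Very_Mild', 'Mild', 'Moderate']

# The hard-coded, already-normalized matrix from the question.
cm = np.array([[1, 0, 0, 0],
               [0.026, 0.921, 0.052, 0],
               [0, 0.166, 0.792, 0.042],
               [0, 0, 0, 1]])

fig, ax = plt.subplots()
image = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
fig.colorbar(image)

# One tick per class, labelled in list order.
tick_marks = np.arange(len(class_names))
ax.set_xticks(tick_marks)
ax.set_xticklabels(class_names, rotation=45)
ax.set_yticks(tick_marks)
ax.set_yticklabels(class_names, rotation=45)
ax.set_ylabel('Actual Classes')
ax.set_xlabel('Predicted Classes')
fig.tight_layout()

# Hypothetical helper variable: read the labels back to check their order.
shown_labels = [tick.get_text() for tick in ax.get_xticklabels()]
print(shown_labels)
```

With a list, the label at tick position i is simply class_names[i], so row and column i of the matrix must describe that same class.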

0 answers:

No answers
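
A related point that may apply if the real targets are strings rather than the integer iris targets used in the question: when no labels argument is given, sklearn.metrics.confusion_matrix orders rows and columns by the sorted label values, which is alphabetical for strings. Passing labels fixes the order explicitly. The sketch below uses made-up y_true and y_pred values purely for illustration; they are not data from the question.

```python
from sklearn.metrics import confusion_matrix

# Desired display order of the classes.
order = ['No_Dementia', 'Very_Mild', 'Mild', 'Moderate']

# Hypothetical string labels, invented only for this illustration.
y_true = ['No_Dementia', 'Very_Mild', 'Mild', 'Moderate', 'Mild', 'Very_Mild']
y_pred = ['No_Dementia', 'Very_Mild', 'Mild', 'Moderate', 'Very_Mild', 'Very_Mild']

# Without `labels`, rows and columns follow the sorted label values:
# Mild, Moderate, No_Dementia, Very_Mild.
default_cm = confusion_matrix(y_true, y_pred)

# With `labels`, rows and columns follow the order given in `order`.
ordered_cm = confusion_matrix(y_true, y_pred, labels=order)

print(default_cm)
print(ordered_cm)
```

The same order list can then be passed as the tick labels, so the matrix rows and the axis labels stay aligned.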