我编写了以下代码,用于对多类数据进行分类并绘制各类别的 ROC 曲线。
from itertools import cycle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import svm, datasets
from sklearn.base import clone
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis as QDA
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_predict
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import label_binarize
##########################################################################################
def _cv_proba_and_predict(estimator, X, y, n_splits=10):
    """Return out-of-fold ``predict_proba`` scores and ``predict`` labels together.

    Produces the same two outputs as
    ``cross_val_predict(estimator, X, y, cv=n_splits, method='predict_proba')``
    followed by ``cross_val_predict(estimator, X, y, cv=n_splits)``, but each
    fold's model is fitted only once instead of twice, roughly halving the cost.

    Parameters
    ----------
    estimator : scikit-learn classifier
        Unfitted template; a fresh clone is fitted on every fold.
    X : array-like of shape (n_samples, n_features)
    y : array-like of shape (n_samples,)
    n_splits : int
        Number of stratified folds. For a classifier with a binary/multiclass
        target, ``cv=<int>`` in ``cross_val_predict`` uses an unshuffled
        ``StratifiedKFold`` with that many splits, so the folds match.

    Returns
    -------
    proba : ndarray of shape (n_samples, n_classes)
        Column ``j`` holds the score for class ``np.unique(y)[j]``.
    pred : ndarray of shape (n_samples,)
        Hard label predicted by the fold model that did not see the sample.
    """
    X_arr = np.asarray(X)
    y_arr = np.asarray(y)
    classes = np.unique(y_arr)
    proba = np.zeros((len(y_arr), len(classes)))
    pred = np.empty(len(y_arr), dtype=y_arr.dtype)

    folds = StratifiedKFold(n_splits=n_splits).split(X_arr, y_arr)
    for train_idx, test_idx in folds:
        model = clone(estimator).fit(X_arr[train_idx], y_arr[train_idx])
        # A training fold may lack some class; place this model's columns at
        # the matching global class positions (absent classes keep 0, which is
        # also what cross_val_predict does for predict_proba).
        cols = np.searchsorted(classes, model.classes_)
        proba[np.ix_(test_idx, cols)] = model.predict_proba(X_arr[test_idx])
        pred[test_idx] = model.predict(X_arr[test_idx])
    return proba, pred


df = pd.read_csv('merged_Zero_Cor_cleaned.tsv', sep='\t')
X = df.drop(columns='class')
y = df['class']

# Take the class list from the data rather than hard-coding [0, 1, 2, 3, 4];
# np.unique's sorted order matches the column order of y_score below.
classes = np.unique(y)
n_classes = len(classes)
# One indicator column per class, aligned with y_score's columns (unlike
# label_binarize, this still yields one column per class when there are two).
y_bin = (np.asarray(y)[:, np.newaxis] == classes).astype(int)

# NOTE(review): QDA already handles multiple classes natively; the OneVsRest
# wrapper is kept to preserve the original model, but QDA() alone would give
# proper multi-class posteriors.
clf = OneVsRestClassifier(QDA())

# A single cross-validation pass yields both outputs (previously two
# cross_val_predict calls trained all 10 fold models twice).
# y_pred holds out-of-fold hard labels; it is not used by the plot below.
y_score, y_pred = _cv_proba_and_predict(clf, X, y, n_splits=10)

# Interpreting y_score: per the scikit-learn documentation/source,
# OneVsRestClassifier.predict_proba takes each per-class binary model's
# positive-class probability and divides each row by its sum. A row such as
# [0, .25, .25, .25, .25] therefore means four binary models were equally
# confident (presumably each near 1.0), and an all-zero row turns into 0/0 =
# NaN. That is why NaNs are zero-filled below. Such saturated 0/1 outputs
# hint at over-confident QDA fits (possibly near-singular class covariances);
# QDA(reg_param=...) may help -- TODO confirm on this data.

lw = 2
fpr = {}
tpr = {}
roc_auc = {}
for i in range(n_classes):
    # Replace NaN scores so roc_curve accepts the column (same effect as the
    # former DataFrame(...).fillna(0), without clobbering the `df` dataset).
    scores = np.nan_to_num(y_score[:, i], nan=0.0)
    fpr[i], tpr[i], _ = roc_curve(y_bin[:, i], scores)
    roc_auc[i] = auc(fpr[i], tpr[i])

colors = cycle(['blue', 'red', 'green', 'black', 'brown'])
for i, (cls, color) in enumerate(zip(classes, colors)):
    plt.plot(fpr[i], tpr[i], color=color, lw=lw,
             label='ROC curve of class {0} (area = {1:0.2f})'
             ''.format(cls, roc_auc[i]))
plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.0])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic for multi-class data')
plt.legend(loc="lower right")
plt.show()
##########################################################################################
为了评估上面代码中模型的性能,我在两行不同的代码中分别计算了预测概率得分和预测标签:
# Out-of-fold class probabilities: cross_val_predict fits one model per fold (10 fits).
y_score = cross_val_predict(clf, X, y, cv=10 ,method='predict_proba')
# Out-of-fold hard labels: a second, independent call, so the 10 fold models are fitted again.
y_pred = cross_val_predict(clf, X, y, cv=10 )
这样做的计算开销很大,因为 10 折交叉验证要完整运行两遍。有没有办法只运行一次就同时得到这两个输出?
更新
或者说,对于下面这种预测概率输出,应该如何解读各个类别?
0 1 2 3 4
0 0.0 0.250000 0.250000 0.250000 2.500000e-01
1 0.0 0.000000 0.000000 1.000000 0.000000e+00
2 0.0 0.250000 0.250000 0.250000 2.500000e-01
3 0.0 0.000000 0.333333 0.333333 3.333333e-01
4 0.0 0.000000 0.000000 1.000000 0.000000e+00
5 0.0 0.000000 0.000000 1.000000 8.744693e-23
6 0.0 0.333333 0.333333 0.333333 9.255446e-105