How can I get both predicted probabilities and predicted labels from `cross_val_predict` in a single call for multi-class prediction?

Asked: 2019-09-09 15:14:08

Tags: python machine-learning scikit-learn cross-validation sklearn-pandas

I wrote the following code to classify multi-class data.

import matplotlib.pyplot as plt
from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc
from sklearn.multiclass import OneVsRestClassifier
from sklearn.model_selection import cross_val_predict
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis as QDA
from itertools import cycle
import pandas as pd 

##########################################################################################
df = pd.read_csv('merged_Zero_Cor_cleaned.tsv',sep='\t')

X = df.drop(columns='class')
y = df['class']

y_bin = label_binarize(y, classes=[0, 1, 2, 3, 4])
n_classes = y_bin.shape[1]

clf = OneVsRestClassifier(QDA())
y_score = cross_val_predict(clf, X, y, cv=10, method='predict_proba')
y_pred = cross_val_predict(clf, X, y, cv=10)

lw = 2

fpr = dict()
tpr = dict()
roc_auc = dict()

for i in range(n_classes):
    # Replace NaN scores with 0 without overwriting the input DataFrame `df`
    scores_i = pd.Series(y_score[:, i]).fillna(0).values
    fpr[i], tpr[i], _ = roc_curve(y_bin[:, i], scores_i)
    roc_auc[i] = auc(fpr[i], tpr[i])

colors = cycle(['blue', 'red', 'green','black', 'brown'])

for i, color in zip(range(n_classes), colors):
    plt.plot(fpr[i], tpr[i], color=color, lw=lw,
             label='ROC curve of class {0} (area = {1:0.2f})'
             ''.format(i, roc_auc[i]))

plt.plot([0, 1], [0, 1], 'k--', lw=lw)

plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.0])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic for multi-class data')
plt.legend(loc="lower right")
plt.show()

##########################################################################################

To check performance in the code above, I compute the predicted probability scores and the predicted labels on two separate lines.

y_score = cross_val_predict(clf, X, y, cv=10, method='predict_proba')
y_pred = cross_val_predict(clf, X, y, cv=10)

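This is computationally expensive, because the whole cross-validation runs twice. Is there a way to get both outputs from a single call?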
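
One possible way to avoid the second cross-validation pass is to derive the labels from the probability matrix itself. The sketch below is an illustration under stated assumptions, not a built-in scikit-learn feature: `proba_and_labels` is a hypothetical helper name, the iris data stands in for the TSV file, and the label for each sample is taken as the argmax of its probability row. `cross_val_predict` documents that the `predict_proba` columns follow the sorted class order, so `np.unique(y)` maps column indices back to labels. The labels may differ from a separate `predict` pass on rows that contain ties or NaN.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis as QDA
from sklearn.model_selection import cross_val_predict
from sklearn.multiclass import OneVsRestClassifier


def proba_and_labels(clf, X, y, cv=10):
    """Hypothetical helper: one cross-validation pass, two outputs."""
    # Out-of-fold class probabilities, shape (n_samples, n_classes)
    y_score = cross_val_predict(clf, X, y, cv=cv, method='predict_proba')
    # Columns are ordered like the sorted unique labels
    classes = np.unique(y)
    # Treat NaN as 0 before picking the most probable class (assumption)
    y_pred = classes[np.argmax(np.nan_to_num(y_score), axis=1)]
    return y_score, y_pred


# Stand-in data (assumption): iris instead of the original TSV file
X_demo, y_demo = load_iris(return_X_y=True)
y_score, y_pred = proba_and_labels(OneVsRestClassifier(QDA()), X_demo, y_demo, cv=10)
```

With this, `y_score` feeds the ROC loop exactly as before, and `y_pred` replaces the second `cross_val_predict` call.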

Update

Alternatively, how should the predicted class be read off from probabilities like these?

      0         1         2         3              4
0      0.0  0.250000  0.250000  0.250000   2.500000e-01
1      0.0  0.000000  0.000000  1.000000   0.000000e+00
2      0.0  0.250000  0.250000  0.250000   2.500000e-01
3      0.0  0.000000  0.333333  0.333333   3.333333e-01
4      0.0  0.000000  0.000000  1.000000   0.000000e+00
5      0.0  0.000000  0.000000  1.000000   8.744693e-23
6      0.0  0.333333  0.333333  0.333333   9.255446e-105

0 answers:

No answers yet.