Sklearn ROC curve for each class in multi-class RandomForest classification

Posted: 2019-12-29 21:40:55

Tags: python scikit-learn random-forest roc

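I have sequence embeddings on which I run a RandomForestClassifier. I can get the confusion matrix and some other scores. Now I want to plot an ROC curve for each class (I have 7 or 8 classes, depending on the data). I have tried roc_auc_score() with the multi-class options, but it only gives me a single overall score, whereas I want a score (and curve) per class so that I can plot them.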
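A minimal sketch of one way to get per-class scores instead of a single number: binarize the labels one-vs-rest with label_binarize and call roc_auc_score on each column of predict_proba separately. The data below is a synthetic stand-in from make_classification (7 classes), not the SGT embeddings, and the model settings are illustrative assumptions.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize

# Synthetic stand-in for the SGT embeddings (assumption): 7 classes, 20 features.
X, y = make_classification(n_samples=1400, n_features=20, n_informative=10,
                           n_classes=7, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=3, stratify=y)

clf = RandomForestClassifier(n_estimators=200, random_state=0)
clf.fit(X_train, y_train)
y_prob = clf.predict_proba(X_test)  # column i holds P(class == clf.classes_[i])

# One-vs-rest: column i of y_bin is 1 where the true label is clf.classes_[i].
y_bin = label_binarize(y_test, classes=clf.classes_)
per_class_auc = {
    cls: roc_auc_score(y_bin[:, i], y_prob[:, i])
    for i, cls in enumerate(clf.classes_)
}
for cls, score in per_class_auc.items():
    print(f"class {cls}: AUC = {score:.3f}")

# The macro "ovr" score is the unweighted mean of these per-class values.
macro_ovr = roc_auc_score(y_test, y_prob, multi_class="ovr", average="macro")
print(f"mean of per-class AUCs: {np.mean(list(per_class_auc.values())):.3f}")
print(f"macro OvR AUC:          {macro_ovr:.3f}")
```

The last two lines are a sanity check: averaging the per-class one-vs-rest AUCs without weights should reproduce the "ovr"/"macro" number the question already prints, which confirms the per-class values are the components of that overall score.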

Here is my code:

import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.utils import class_weight
from sklearn.model_selection import train_test_split
from sgt import Sgt
data = pd.read_csv(r'/home/clones.tsv', sep='\t')
X = data['A']
Y = data['B']
seq = X.map(str) + Y
def split(word):
    # split a sequence string into a list of single characters
    return list(word)
sequences = [split(x) for x in seq]
epi = data['epi']
encoder = LabelEncoder()
encoder.fit(epi)
y = encoder.transform(epi)
sgt = Sgt(kappa=10, lengthsensitive=False)
embedding = sgt.fit_transform(corpus=sequences)
X_train, X_test, y_train, y_test = train_test_split(embedding, y, test_size=0.2, random_state=3)

from sklearn.ensemble import RandomForestClassifier

classifier = RandomForestClassifier(n_estimators=500, random_state=0)
classifier.fit(X_train, y_train)
y_pred = classifier.predict(X_test)

from sklearn import metrics
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, f1_score, roc_auc_score

print(confusion_matrix(y_test, y_pred))
print(classification_report(y_test, y_pred))
print(accuracy_score(y_test, y_pred))
print(f1_score(y_test,y_pred,average='weighted'))

y_prob = classifier.predict_proba(X_test)
macro_roc_auc_ovo = roc_auc_score(y_test, y_prob, multi_class="ovo", average="macro")
weighted_roc_auc_ovo = roc_auc_score(y_test, y_prob, multi_class="ovo", average="weighted")
macro_roc_auc_ovr = roc_auc_score(y_test, y_prob, multi_class="ovr", average="macro")
weighted_roc_auc_ovr = roc_auc_score(y_test, y_prob, multi_class="ovr",average="weighted")

print("One-vs-One ROC AUC scores:\n{:.6f} (macro),\n{:.6f} "
      "(weighted by prevalence)"
      .format(macro_roc_auc_ovo, weighted_roc_auc_ovo))
print("One-vs-Rest ROC AUC scores:\n{:.6f} (macro),\n{:.6f} "
      "(weighted by prevalence)"
      .format(macro_roc_auc_ovr, weighted_roc_auc_ovr))
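To actually draw one curve per class, a sketch along the following lines could work: roc_curve is called once per one-vs-rest column and auc summarizes each curve. The helper names per_class_roc and plot_per_class_roc are hypothetical, and matplotlib is assumed to be available only for the plotting helper (it is imported inside that function). The two-class branch is an assumption added for robustness, since label_binarize returns a single column when there are only two classes.

```python
import numpy as np
from sklearn.metrics import auc, roc_curve
from sklearn.preprocessing import label_binarize


def per_class_roc(y_true, y_prob, classes):
    """Hypothetical helper: return {class: (fpr, tpr, auc)} via one-vs-rest."""
    y_bin = label_binarize(y_true, classes=classes)
    if y_bin.shape[1] == 1:
        # Two classes: label_binarize gives one column, so add the complement.
        y_bin = np.hstack([1 - y_bin, y_bin])
    curves = {}
    for i, cls in enumerate(classes):
        # Column i of y_prob must correspond to classes[i] (classifier.classes_ order).
        fpr, tpr, _ = roc_curve(y_bin[:, i], y_prob[:, i])
        curves[cls] = (fpr, tpr, auc(fpr, tpr))
    return curves


def plot_per_class_roc(curves, class_names=None):
    """Hypothetical helper: draw all per-class curves on one set of axes."""
    import matplotlib.pyplot as plt  # lazy import: computing curves needs no matplotlib

    fig, ax = plt.subplots(figsize=(7, 6))
    for cls, (fpr, tpr, roc_auc) in curves.items():
        name = class_names[cls] if class_names is not None else cls
        ax.plot(fpr, tpr, label=f"{name} (AUC = {roc_auc:.2f})")
    ax.plot([0, 1], [0, 1], "k--", label="chance")
    ax.set_xlabel("False positive rate")
    ax.set_ylabel("True positive rate")
    ax.set_title("One-vs-rest ROC curve per class")
    ax.legend(loc="lower right")
    return fig
```

In the script above this would presumably be called as curves = per_class_roc(y_test, y_prob, classifier.classes_) followed by plot_per_class_roc(curves, class_names=encoder.classes_) and plt.show(). Because y was produced by LabelEncoder, each integer in classifier.classes_ indexes the original label in encoder.classes_, so the legend shows the real class names. If some class happens to be absent from y_test, roc_curve cannot compute a true positive rate for it and that curve will be undefined.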

I would appreciate any help.

0 answers