来自插入符号中训练数据的ROC曲线

时间:2015-06-30 12:47:14

标签: r r-caret roc

使用R包插入符号,如何根据train()函数的交叉验证结果生成ROC曲线?

说,我做了以下事情:

data(Sonar)
ctrl <- trainControl(method="cv", 
  summaryFunction=twoClassSummary, 
  classProbs=T)
rfFit <- train(Class ~ ., data=Sonar, 
  method="rf", preProc=c("center", "scale"), 
  trControl=ctrl)

训练函数超过一系列mtry参数并计算ROC AUC。我想看看相关的ROC曲线 - 我该怎么做?

注意:如果用于采样的方法是LOOCV,那么rfFit将在rfFit$pred槽中包含一个非空数据帧,这似乎正是我所需要的。但是,我需要为&#34; cv&#34;方法(k倍验证)而不是LOO。

另外:不,曾经包含在以前版本的插入符中的roc函数不是答案 - 这是一个低级函数,如果你不这样做,就不能使用它。 t具有每个交叉验证样本的预测概率。

3 个答案:

答案 0 :(得分:31)

savePredictions = TRUE中只缺少ctrl参数(这也适用于其他重采样方法):

library(caret)
library(mlbench)
data(Sonar)
ctrl <- trainControl(method="cv", 
                     summaryFunction=twoClassSummary, 
                     classProbs=T,
                     savePredictions = T)
rfFit <- train(Class ~ ., data=Sonar, 
               method="rf", preProc=c("center", "scale"), 
               trControl=ctrl)
library(pROC)
# Select a parameter setting
selectedIndices <- rfFit$pred$mtry == 2
# Plot:
plot.roc(rfFit$pred$obs[selectedIndices],
         rfFit$pred$M[selectedIndices])

ROC

也许我错过了一些东西,但一个小问题是train总是估计AUC值略微不同于plot.rocpROC::auc(绝对差值<0.005),尽管{{1}使用twoClassSummary估算AUC。 编辑:我认为这是因为来自pROC::auc的ROC是使用单独的CV集的AUC的平均值,这里我们同时计算所有重采样的AUC以获得总体AUC

更新由于这引起了一些关注,以下是使用train plotROC::geom_roc()的解决方案:

ggplot2

ggplot_roc

答案 1 :(得分:12)

在这里,我修改了@ thei1e的情节,其他人可能会觉得有帮助。

训练模型并进行预测

library(caret)
library(ggplot2)
library(mlbench)
library(plotROC)

data(Sonar)

ctrl <- trainControl(method="cv", summaryFunction=twoClassSummary, classProbs=T,
                     savePredictions = T)

rfFit <- train(Class ~ ., data=Sonar, method="rf", preProc=c("center", "scale"), 
               trControl=ctrl)

# Select a parameter setting
selectedIndices <- rfFit$pred$mtry == 2

更新了ROC曲线图

g <- ggplot(rfFit$pred[selectedIndices, ], aes(m=M, d=factor(obs, levels = c("R", "M")))) + 
  geom_roc(n.cuts=0) + 
  coord_equal() +
  style_roc()

g + annotate("text", x=0.75, y=0.25, label=paste("AUC =", round((calc_auc(g))$AUC, 4)))

enter image description here

答案 2 :(得分:0)

已于2019年更新。这是最简单的方法https://cran.r-project.org/web/packages/MLeval/index.html。从Caret对象中获取最佳参数,然后计算概率,然后计算许多度量和图,包括:ROC曲线,PR曲线,PRG曲线和校准曲线。您可以将来自不同模型的多个对象放入其中以比较结果。

library(MLeval)
library(caret)

data(Sonar)
ctrl <- trainControl(method="cv", 
  summaryFunction=twoClassSummary, 
  classProbs=T)
rfFit <- train(Class ~ ., data=Sonar, 
  method="rf", preProc=c("center", "scale"), 
  trControl=ctrl)

## run MLeval

res <- evalm(rfFit)

## get ROC

res$roc

## get calibration curve

res$cc

## get precision recall gain curve

res$prg

enter image description here

enter image description here

enter image description here