在插入符中使用summaryFunction,以实现分类中的自定义性能指标

时间:2018-06-19 14:49:48

标签: r r-caret

可复制的示例:

library(mlbench)
library(caret)
data(Sonar)
set.seed(998)
inTraining <- createDataPartition(Sonar$Class, p = .75, list = FALSE)
training <- Sonar[ inTraining,]
testing  <- Sonar[-inTraining,]

twoClassSummary<- function (data, lev = NULL, model = NULL){
  browser()
  lvls <- levels(data$obs)
  if (length(lvls) > 2) 
    stop(paste("Your outcome has", length(lvls), "levels. The twoClassSummary() function isn't appropriate."))
  requireNamespaceQuietStop("ModelMetrics")
  if (!all(levels(data[, "pred"]) == lvls)) 
    stop("levels of observed and predicted data do not match")
  rocAUC <- ModelMetrics::auc(ifelse(data$obs == lev[2], 0, 
                                     1), data[, lvls[1]])
  out <- c(rocAUC, sensitivity(data[, "pred"], data[, "obs"], 
                               lev[1]), specificity(data[, "pred"], data[, "obs"], lev[2]))
  names(out) <- c("ROC", "Sens", "Spec")
  out
}

gbmGrid <-  expand.grid(interaction.depth = c(1, 5, 9), 
                        n.trees = (1:30)*50, 
                        shrinkage = 0.1,
                        n.minobsinnode = 20)
fitControl <- trainControl(method = "repeatedcv",
                           number = 10,
                           repeats = 10,
                           ## Estimate class probabilities
                           classProbs = TRUE,
                           ## Evaluate performance using 
                           ## the following function
                           summaryFunction = twoClassSummary)

set.seed(825)
gbmFit3 <- train(Class ~ ., data = training, 
                 method = "gbm", 
                 trControl = fitControl, 
                 verbose = FALSE, 
                 tuneGrid = gbmGrid,
                 ## Specify which metric to optimize
                 metric = "ROC")

以上代码由caret教程中的代码组成。我想做的是编写自己的summaryFunction,但是在执行性能计算期间data似乎有错误。

为了能够观察传递给data函数的twoClassSummary对象,我在函数开始处引入了browser()行(请参见上面的可复制示例) )。

现在,在调用train函数之后,我们输入twoClassSummary,其数据对象如下:

Browse[1]> data
   pred obs         M          R rowIndex
1     R   R 0.9885415 0.01145848       69
2     M   M 0.5948173 0.40518271      125
3     M   R 0.1687232 0.83127685      100
4     M   M 0.8889852 0.11101480       90
5     R   M 0.6593955 0.34060452       77
6     R   R 0.4098489 0.59015115      107
7     M   R 0.6498745 0.35012546      147
8     R   M 0.6734522 0.32654777       57
9     R   R 0.5569204 0.44307962       14
10    M   M 0.5084563 0.49154370        8

我发现这很尴尬,因为类概率的argmax与pred类不对应。另外,相应的rowIndex中的实际类与obs(观察到的类)不匹配:

Browse[1]> training$Class[data$rowIndex]
 [1] R M M M M M M R R R

我在这里想念什么?

0 个答案:

没有答案