How do I generate a confusion matrix for the held-out samples with caret's "xgbDART"?

Time: 2018-12-06 15:01:07

Tags: r r-caret

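I am training a model with the "xgbDART" method available in caret. The resampling method is "repeatedcv".

Is it possible to generate a confusion matrix for the internal held-out samples? I thought printing the final model would produce one, as it does for the "rf" algorithm, but it does not. Any suggestions would be helpful.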

1 answer:

Answer 0: (score: 1)

To get a confusion matrix after training with caret, call caret::confusionMatrix on the resulting train object. Here is an example on the Sonar data:

library(mlbench)
library(caret)
library(xgboost)
data(Sonar)
ctrl <- trainControl(method = "repeatedcv", 
                     number = 2,
                     repeats = 2)


grid <- expand.grid(max_depth = 5,
                    nrounds = 500,
                    eta =  .01,
                    colsample_bytree = 0.7,
                    gamma = 0.1,
                    min_child_weight = 1,
                    subsample = .6,
                    rate_drop = c(.1, .3),
                    skip_drop = c(.1, .3))


fit.dart <- train(Class ~ .,
                  data =  Sonar, 
                  method = "xgbDART", 
                  metric = "Accuracy",
                  trControl = ctrl, 
                  tuneGrid = grid)

confusionMatrix(fit.dart)
#output
Cross-Validated (2 fold, repeated 2 times) Confusion Matrix 

(entries are percentual average cell counts across resamples)

          Reference
Prediction    M    R
         M 44.5 13.7
         R  8.9 32.9

 Accuracy (average) : 0.774
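
The matrix above is expressed as percentages of all held-out predictions. If raw counts are preferred, which is closer to what printing an "rf" final model shows, the `norm` argument of `confusionMatrix()` for train objects can be set to "none". The sketch below is only an illustration: it assumes a single tuning combination with a much smaller `nrounds` than the grid above to keep the run short, and `grid_small`, `fit_small` and `cm_counts` are illustrative names.

```r
library(mlbench)
library(caret)

data(Sonar)
set.seed(1)  # assumed seed, only so the folds are reproducible

ctrl <- trainControl(method = "repeatedcv", number = 2, repeats = 2)

# Single tuning combination with few rounds (assumption: chosen for speed)
grid_small <- expand.grid(nrounds = 20, max_depth = 5, eta = 0.01, gamma = 0.1,
                          subsample = 0.6, colsample_bytree = 0.7,
                          rate_drop = 0.1, skip_drop = 0.1,
                          min_child_weight = 1)

fit_small <- train(Class ~ ., data = Sonar, method = "xgbDART",
                   trControl = ctrl, tuneGrid = grid_small)

# norm = "none" sums the held-out cell counts over all resamples
# instead of reporting percentages
cm_counts <- confusionMatrix(fit_small, norm = "none")
cm_counts$table
```

With 2 folds repeated 2 times, every row of Sonar is held out twice, so the cells should add up to twice the number of rows.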

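To create a custom confusion matrix (for example, with a custom threshold and without averaging over the resamples), set classProbs = TRUE and savePredictions = TRUE in trainControl, then refit the model. The held-out predictions are then stored in fit.dart$pred.

Now, for example, to apply a cutoff threshold of 0.3 to the pooled held-out data, you can do: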
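
As a minimal sketch of that control object, the settings could look like the block below. It uses `savePredictions = "final"` rather than `TRUE`; that is an assumption made here so that only the held-out predictions for the selected tuning combination are kept. With `TRUE` (equivalent to "all"), predictions for every row of the tuning grid are stored and would need to be filtered to the best tune before building a matrix.

```r
library(caret)

# Same resampling scheme as before, but keeping class probabilities and
# the held-out predictions so a custom confusion matrix can be built later.
ctrl <- trainControl(method = "repeatedcv",
                     number = 2,
                     repeats = 2,
                     classProbs = TRUE,          # store P(M) and P(R) per held-out row
                     savePredictions = "final")  # assumption: keep only the best tune

# Pass this object as trControl = ctrl in train() and refit the model.
```

After refitting, the `pred` element of the train object holds one row per held-out prediction, including the observed class (`obs`), the predicted class (`pred`), one probability column per class and the `Resample` identifier.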

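# Note: confusionMatrix() takes the predictions first and the reference second.
# Here the observed classes are passed first, so the rows labelled "Prediction"
# in the output below hold the observed classes and the columns the thresholded
# predictions.
confusionMatrix(fit.dart$pred$obs,
                factor(ifelse(fit.dart$pred$R > 0.3, "R", "M"), levels = c("M", "R")))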
#output
Confusion Matrix and Statistics

          Reference
Prediction   M   R
         M 106 116
         R   8 186

               Accuracy : 0.7019          
                 95% CI : (0.6554, 0.7455)
    No Information Rate : 0.726           
    P-Value [Acc > NIR] : 0.8753          

                  Kappa : 0.4214          
 Mcnemar's Test P-Value : <2e-16          

            Sensitivity : 0.9298          
            Specificity : 0.6159          
         Pos Pred Value : 0.4775          
         Neg Pred Value : 0.9588          
             Prevalence : 0.2740          
         Detection Rate : 0.2548          
   Detection Prevalence : 0.5337          
      Balanced Accuracy : 0.7729          

       'Positive' Class : M
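
For comparison, here is a small self-contained sketch of the same thresholding step with the predictions passed as `data` and the observed classes as `reference`, which is the order `confusionMatrix()` documents. The data frame `held_out` is a hypothetical stand-in for fit.dart$pred with made-up values, and `threshold`, `pred_class` and `per_resample` are illustrative names.

```r
library(caret)

# Hypothetical stand-in for fit.dart$pred: observed class, predicted
# probability of class "R", and the resample each row was held out in.
held_out <- data.frame(
  obs = factor(c("M", "M", "M", "R", "R", "R", "M", "R"),
               levels = c("M", "R")),
  R = c(0.10, 0.35, 0.20, 0.80, 0.25, 0.60, 0.05, 0.40),
  Resample = rep(c("Fold1.Rep1", "Fold2.Rep1"), each = 4)
)

# Turn probabilities into classes with a custom cutoff
threshold <- 0.3
pred_class <- factor(ifelse(held_out$R > threshold, "R", "M"),
                     levels = c("M", "R"))

# Pooled matrix: predictions as data, observed classes as reference
cm <- confusionMatrix(data = pred_class,
                      reference = held_out$obs,
                      positive = "R")
cm$table

# One count table per resample instead of a pooled or averaged one
per_resample <- lapply(
  split(seq_len(nrow(held_out)), held_out$Resample),
  function(idx) table(Prediction = pred_class[idx],
                      Reference = held_out$obs[idx])
)
per_resample
```

Setting `positive = "R"` is a choice made for this sketch; without it, caret treats the first factor level ("M") as the positive class, as in the output above.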