从保存在dplyr的列表列中的confusionMatrix中提取内容

时间:2018-09-27 02:01:22

标签: r dplyr

如下面的代码所示,在交叉验证之后,我试图为每个折叠提取模型指标。我保存了重新采样中的所有预测,将数据按倍数进行分组,计算每组的混淆矩阵,并将混淆矩阵对象另存为列表列cm。现在,我需要从列中保存的对象中提取度量标准信息,例如精度等。我的示例代码如下所示。

library(caret)
iris2 = iris %>% 
    filter(Species != 'setosa') %>%
    mutate(Species = factor(Species))

train.control <- trainControl(method="cv", 
                           number=5,
                           summaryFunction = twoClassSummary,
                           classProbs = TRUE,
                           savePredictions='all')
rf = train(Species~., data=iris2,  method = 'rf',
           metric = 'ROC', trControl=train.control)
rf$pred %>% group_by(Resample) %>%
    do(cm = confusionMatrix(.$pred, .$obs),
       Accuracy = map(cm, ~.x$byClass['Precision'])) 

我收到错误消息:

Error in .x$byClass : $ operator is invalid for atomic vectors

我不知道为什么它不起作用。我的问题是如何修改最后一行以使其起作用?谢谢

1 个答案:

答案 0 :(得分:1)

您可以使用ungroup(),然后只需mutate Accuracy,只需list的每一折页访问unlist()的特定部分即可提取{元素本身。

rf$pred %>% 
  group_by(Resample) %>%
  do(cm = confusionMatrix(.$pred, .$obs)) %>% 
  ungroup() %>% 
  mutate(neg_pred_value = map(cm, ~ .x[["byClass"]][["Neg Pred Value"]]) %>% unlist(),
         accuracy = map(cm, ~ .x[["byClass"]][["Precision"]]) %>% unlist())

使用上面的代码,我们得到以下输出为tibble

# A tibble: 5 x 4
  Resample                    cm neg_pred_value  accuracy
     <chr>                <list>          <dbl>     <dbl>
1    Fold1 <S3: confusionMatrix>      0.9090909 1.0000000
2    Fold2 <S3: confusionMatrix>      1.0000000 1.0000000
3    Fold3 <S3: confusionMatrix>      1.0000000 1.0000000
4    Fold4 <S3: confusionMatrix>      0.8181818 0.8888889
5    Fold5 <S3: confusionMatrix>      1.0000000 0.9090909