Error when using caret::confusionMatrix() on tidymodels collect_predictions() for model evaluation

Time: 2020-10-21 13:13:05

Tags: r machine-learning r-caret tidymodels

I'm new to R and have built a classification model with tidymodels. Below is the output of collect_predictions(model):

collect_predictions(members_final) %>% print()

# A tibble: 19,126 x 6
   id               .pred_died .pred_survived  .row .pred_class died    
   <chr>                 <dbl>          <dbl> <int> <fct>       <fct>   
 1 train/test split      0.285          0.715     5 survived    survived
 2 train/test split      0.269          0.731     6 survived    survived
 3 train/test split      0.298          0.702     7 survived    survived
 4 train/test split      0.276          0.724     8 survived    survived
 5 train/test split      0.251          0.749    10 survived    survived
 6 train/test split      0.124          0.876    18 survived    survived
 7 train/test split      0.127          0.873    21 survived    survived
 8 train/test split      0.171          0.829    26 survived    survived
 9 train/test split      0.158          0.842    30 survived    survived
10 train/test split      0.150          0.850    32 survived    survived
# … with 19,116 more rows

It works with the yardstick functions:

collect_predictions(members_final) %>%
  conf_mat(died, .pred_class)

          Truth
Prediction  died survived
  died       196     7207
  survived    90    11633

But when I pipe the output of collect_predictions() into caret::confusionMatrix(), it doesn't work:

collect_predictions(members_final) %>% 
  caret::confusionMatrix(as.factor(died), as.factor(.pred_class))

############## output #################
Error: `data` and `reference` should be factors with the same levels.
Traceback:

1. collect_predictions(members_final) %>% caret::confusionMatrix(as.factor(died), 
 .     as.factor(.pred_class))

2. withVisible(eval(quote(`_fseq`(`_lhs`)), env, env))

3. eval(quote(`_fseq`(`_lhs`)), env, env)

4. eval(quote(`_fseq`(`_lhs`)), env, env)

I'm not sure what is going wrong here. How can I fix this so that I can evaluate the model with caret?

My reason for using caret's evaluation is to find out which class is treated as positive and which as negative.

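Is there another way to find out the positive/negative class? For example, is it correct that levels(df$class) shows which class the model uses as the positive class?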

1 Answer

Answer (score: 1)

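If you already have predictions, such as the output of collect_predictions(), you don't want to pipe them into the caret function. Unlike the yardstick functions, it does not take a data frame as its first argument: its first two arguments, data and reference, are the predicted classes and the true classes, both as factors. With %>%, the data frame lands in the data argument, which is not a factor, hence the error. Pass the columns as vectors instead: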

library(caret)
#> Loading required package: lattice
#> Loading required package: ggplot2
data("two_class_example", package = "yardstick")

confusionMatrix(two_class_example$predicted, two_class_example$truth)
#> Confusion Matrix and Statistics
#> 
#>           Reference
#> Prediction Class1 Class2
#>     Class1    227     50
#>     Class2     31    192
#>                                           
#>                Accuracy : 0.838           
#>                  95% CI : (0.8027, 0.8692)
#>     No Information Rate : 0.516           
#>     P-Value [Acc > NIR] : <2e-16          
#>                                           
#>                   Kappa : 0.6749          
#>                                           
#>  Mcnemar's Test P-Value : 0.0455          
#>                                           
#>             Sensitivity : 0.8798          
#>             Specificity : 0.7934          
#>          Pos Pred Value : 0.8195          
#>          Neg Pred Value : 0.8610          
#>              Prevalence : 0.5160          
#>          Detection Rate : 0.4540          
#>    Detection Prevalence : 0.5540          
#>       Balanced Accuracy : 0.8366          
#>                                           
#>        'Positive' Class : Class1          
#> 

Created on 2020-10-21 by the reprex package (v0.3.0.9001)
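On the follow-up about positive and negative classes: for a two-level factor, caret uses the first factor level as the positive class unless the positive argument is given, which is where "'Positive' Class : Class1" in the output above comes from. So levels() does tell you which class is treated as positive, as long as the factor levels are in the order you expect. A minimal sketch, reusing two_class_example and assuming both caret and yardstick are installed:

```r
library(caret)
data("two_class_example", package = "yardstick")

# By default the positive class is the first level of the factors.
levels(two_class_example$truth)
cm_default <- confusionMatrix(two_class_example$predicted, two_class_example$truth)
cm_default$positive

# Choose the positive class explicitly; sensitivity and specificity swap roles.
cm_class2 <- confusionMatrix(two_class_example$predicted, two_class_example$truth,
                             positive = "Class2")
cm_class2$positive
cm_class2$byClass[c("Sensitivity", "Specificity")]
```

With positive = "Class2", the reported sensitivity equals the specificity from the default call, because the roles of the two classes are simply exchanged.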

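It looks like your variable names would be died and .pred_class; you need to save the data frame containing the predictions as an object to access them.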
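Applied to the model in the question, a sketch might look like the following. It assumes caret and magrittr are installed. Because members_final is not available here, a small made-up data frame with the same column names (died, .pred_class) stands in for collect_predictions(members_final); swap in the real call. Note that the question passes the truth first, whereas caret expects predictions first and truth second. The with() line is an optional alternative that is not from the original answer: it evaluates the call inside the data frame, so the columns can be referenced by name while keeping the pipe. Since died is the first level in the question's conf_mat output, caret would report died as the positive class.

```r
library(caret)
library(magrittr)  # provides %>%

# Hypothetical stand-in for collect_predictions(members_final);
# in practice use: preds <- collect_predictions(members_final)
preds <- data.frame(
  died        = factor(c("died", "survived", "survived", "died", "survived"),
                       levels = c("died", "survived")),
  .pred_class = factor(c("survived", "survived", "survived", "died", "died"),
                       levels = c("died", "survived"))
)

# Pass vectors: predictions first (`data`), truth second (`reference`).
cm <- confusionMatrix(data = preds$.pred_class, reference = preds$died)
cm$table
cm$positive

# Pipe-friendly alternative: with() evaluates the call inside `preds`,
# so the column names can be used directly.
cm_piped <- preds %>% with(confusionMatrix(.pred_class, died))
```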