I'm new to R and built a classification model with tidymodels. Below is the result of collect_predictions(model):
collect_predictions(members_final) %>% print()
# A tibble: 19,126 x 6
id .pred_died .pred_survived .row .pred_class died
<chr> <dbl> <dbl> <int> <fct> <fct>
1 train/test split 0.285 0.715 5 survived survived
2 train/test split 0.269 0.731 6 survived survived
3 train/test split 0.298 0.702 7 survived survived
4 train/test split 0.276 0.724 8 survived survived
5 train/test split 0.251 0.749 10 survived survived
6 train/test split 0.124 0.876 18 survived survived
7 train/test split 0.127 0.873 21 survived survived
8 train/test split 0.171 0.829 26 survived survived
9 train/test split 0.158 0.842 30 survived survived
10 train/test split 0.150 0.850 32 survived survived
# … with 19,116 more rows
It works with yardstick functions:
collect_predictions(members_final) %>%
conf_mat(died, .pred_class)
Truth
Prediction died survived
died 196 7207
survived 90 11633
But when I pipe the collect_predictions() output into caret::confusionMatrix(), it does not work:
collect_predictions(members_final) %>%
caret::confusionMatrix(as.factor(died), as.factor(.pred_class))
############## output #################
Error: `data` and `reference` should be factors with the same levels.
Traceback:
1. collect_predictions(members_final) %>% caret::confusionMatrix(as.factor(died),
. as.factor(.pred_class))
2. withVisible(eval(quote(`_fseq`(`_lhs`)), env, env))
3. eval(quote(`_fseq`(`_lhs`)), env, env)
4. eval(quote(`_fseq`(`_lhs`)), env, env)
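I'm not sure what is going on here. How can I fix this so that I can use caret's evaluation?
The reason for using caret's evaluation is to find out which class is treated as positive and which as negative.
Is there another way to find the positive/negative class? (Is checking levels(df$class) the right way to find the positive class used by the model?)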
Answer 0 (score: 1)
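If you have predictions, such as the output of collect_predictions(), you don't want to pipe them into the caret function. Unlike the yardstick functions, confusionMatrix() does not take the data frame as its first argument. Instead, pass the arguments as vectors: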
library(caret)
#> Loading required package: lattice
#> Loading required package: ggplot2
data("two_class_example", package = "yardstick")
confusionMatrix(two_class_example$predicted, two_class_example$truth)
#> Confusion Matrix and Statistics
#>
#> Reference
#> Prediction Class1 Class2
#> Class1 227 50
#> Class2 31 192
#>
#> Accuracy : 0.838
#> 95% CI : (0.8027, 0.8692)
#> No Information Rate : 0.516
#> P-Value [Acc > NIR] : <2e-16
#>
#> Kappa : 0.6749
#>
#> Mcnemar's Test P-Value : 0.0455
#>
#> Sensitivity : 0.8798
#> Specificity : 0.7934
#> Pos Pred Value : 0.8195
#> Neg Pred Value : 0.8610
#> Prevalence : 0.5160
#> Detection Rate : 0.4540
#> Detection Prevalence : 0.5540
#> Balanced Accuracy : 0.8366
#>
#> 'Positive' Class : Class1
#>
Created on 2020-10-21 by the reprex package (v0.3.0.9001)
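It looks like your variable names will be died and .pred_class; you will need to save the data frame containing the predictions as an object so that you can access those columns.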
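As a minimal sketch of that last step, the snippet below saves the predictions as an object and passes the two columns to caret as vectors. Because members_final from the question is not available here, the preds data frame is a hypothetical stand-in built from yardstick's two_class_example with the question's column names; in the real case you would use preds <- collect_predictions(members_final) instead. It also shows one way to check the positive class: for a two-class problem, caret's confusionMatrix() uses the first factor level as the positive class unless you set its positive argument.

```r
library(caret)

# Hypothetical stand-in for `preds <- collect_predictions(members_final)`:
# same column names as in the question, but built from example data.
data("two_class_example", package = "yardstick")
preds <- data.frame(
  died        = two_class_example$truth,      # observed class
  .pred_class = two_class_example$predicted   # predicted class
)

# Pass the columns as vectors: predictions first, truth second.
cm <- confusionMatrix(data = preds$.pred_class, reference = preds$died)

# Equivalent call that avoids repeating `preds$`.
cm_with <- with(preds, confusionMatrix(data = .pred_class, reference = died))

# The class caret treated as positive, and the first factor level of the truth
# column; by default these agree for a two-class outcome.
cm$positive
levels(preds$died)[1]
```

To treat a different level as positive, confusionMatrix() accepts a positive argument naming that level, for example positive = "survived" with the question's data. yardstick metrics follow a similar convention through their event_level argument, which defaults to "first".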