如何使用 vfold_cv 获得欧洲防风草多项逻辑回归模型的系数?

时间:2021-06-27 12:02:58

标签: r tidymodels

我用 resamples_fitwork_flow()V-Fold Cross-Validation。 我的模型是逻辑回归。

如何使用 V-Fold 交叉验证获得欧洲防风草逻辑回归模型的系数?

如果我的 V-Fold Cross-Validation v=5,我想得到 5 倍系数。

1 个答案:

答案 0 :(得分:2)

您通常不想使用 fit_resamples() 来训练和保留五个模型; fit_resamples() 函数的主要目的是使用 resampling to estimate performance。五个模型适合然后扔掉。

但是,如果您确实有一些用例想要保留适合的模型,例如 in this article,那么您可以使用 extract_model

library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
data(penguins)

set.seed(2021)
penguin_split <- penguins %>%
  filter(!is.na(sex)) %>%
  initial_split(strata = sex)
penguin_train <- training(penguin_split)
penguin_test <- testing(penguin_split)

penguin_folds <- vfold_cv(penguin_train, v = 5, strata = sex)
penguin_folds
#> #  5-fold cross-validation using stratification 
#> # A tibble: 5 x 2
#>   splits           id   
#>   <list>           <chr>
#> 1 <split [198/51]> Fold1
#> 2 <split [199/50]> Fold2
#> 3 <split [199/50]> Fold3
#> 4 <split [200/49]> Fold4
#> 5 <split [200/49]> Fold5

glm_spec <- logistic_reg() %>%
  set_engine("glm") 

glm_rs <- workflow() %>%
  add_formula(sex ~ species + bill_length_mm + bill_depth_mm + body_mass_g) %>%
  add_model(glm_spec) %>%
  fit_resamples(
    resamples = penguin_folds,
    control = control_resamples(extract = extract_model, save_pred = TRUE)
  )

既然您已经在重采样中使用了 extract_model,它就会出现在您的结果中,并且您拥有每个折叠可用的模型。

glm_rs
#> # Resampling results
#> # 5-fold cross-validation using stratification 
#> # A tibble: 5 x 6
#>   splits        id    .metrics       .notes        .extracts     .predictions   
#>   <list>        <chr> <list>         <list>        <list>        <list>         
#> 1 <split [198/… Fold1 <tibble [2 × … <tibble [0 ×… <tibble [1 ×… <tibble [51 × …
#> 2 <split [199/… Fold2 <tibble [2 × … <tibble [0 ×… <tibble [1 ×… <tibble [50 × …
#> 3 <split [199/… Fold3 <tibble [2 × … <tibble [0 ×… <tibble [1 ×… <tibble [50 × …
#> 4 <split [200/… Fold4 <tibble [2 × … <tibble [0 ×… <tibble [1 ×… <tibble [49 × …
#> 5 <split [200/… Fold5 <tibble [2 × … <tibble [0 ×… <tibble [1 ×… <tibble [49 × …

glm_rs$.extracts[[1]]
#> # A tibble: 1 x 2
#>   .extracts .config             
#>   <list>    <chr>               
#> 1 <glm>     Preprocessor1_Model1

您可以使用 函数来获取系数,如果这正是您所需要的。

glm_rs %>% 
  dplyr::select(id, .extracts) %>%
  unnest(cols = .extracts) %>%
  mutate(tidied = map(.extracts, tidy)) %>%
  unnest(tidied)
#> # A tibble: 30 x 8
#>    id    .extracts .config      term      estimate std.error statistic   p.value
#>    <chr> <list>    <chr>        <chr>        <dbl>     <dbl>     <dbl>     <dbl>
#>  1 Fold1 <glm>     Preprocesso… (Interce… -7.44e+1  12.6         -5.89   3.75e-9
#>  2 Fold1 <glm>     Preprocesso… speciesC… -6.59e+0   1.82        -3.61   3.03e-4
#>  3 Fold1 <glm>     Preprocesso… speciesG… -7.49e+0   2.54        -2.95   3.18e-3
#>  4 Fold1 <glm>     Preprocesso… bill_len…  5.56e-1   0.151        3.67   2.40e-4
#>  5 Fold1 <glm>     Preprocesso… bill_dep…  1.72e+0   0.424        4.06   4.83e-5
#>  6 Fold1 <glm>     Preprocesso… body_mas…  5.88e-3   0.00130      4.51   6.44e-6
#>  7 Fold2 <glm>     Preprocesso… (Interce… -6.87e+1  11.3         -6.06   1.37e-9
#>  8 Fold2 <glm>     Preprocesso… speciesC… -5.59e+0   1.75        -3.20   1.39e-3
#>  9 Fold2 <glm>     Preprocesso… speciesG… -7.61e+0   2.80        -2.71   6.65e-3
#> 10 Fold2 <glm>     Preprocesso… bill_len…  4.88e-1   0.145        3.36   7.88e-4
#> # … with 20 more rows

reprex package (v2.0.0) 于 2021 年 6 月 27 日创建