我用 resamples_fit
和 work_flow()
做 V-Fold Cross-Validation
。
我的模型是逻辑回归。
如何使用 V-Fold 交叉验证获得欧洲防风草逻辑回归模型的系数?
如果我的 V-Fold Cross-Validation v=5
,我想得到 5 倍系数。
答案 0 :(得分:2)
您通常不想使用 fit_resamples()
来训练和保留五个模型; fit_resamples()
函数的主要目的是使用 resampling to estimate performance。五个模型适合然后扔掉。
但是,如果您确实有一些用例想要保留适合的模型,例如 in this article,那么您可以使用 extract_model
。
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#> method from
#> required_pkgs.model_spec parsnip
data(penguins)
set.seed(2021)
penguin_split <- penguins %>%
filter(!is.na(sex)) %>%
initial_split(strata = sex)
penguin_train <- training(penguin_split)
penguin_test <- testing(penguin_split)
penguin_folds <- vfold_cv(penguin_train, v = 5, strata = sex)
penguin_folds
#> # 5-fold cross-validation using stratification
#> # A tibble: 5 x 2
#> splits id
#> <list> <chr>
#> 1 <split [198/51]> Fold1
#> 2 <split [199/50]> Fold2
#> 3 <split [199/50]> Fold3
#> 4 <split [200/49]> Fold4
#> 5 <split [200/49]> Fold5
glm_spec <- logistic_reg() %>%
set_engine("glm")
glm_rs <- workflow() %>%
add_formula(sex ~ species + bill_length_mm + bill_depth_mm + body_mass_g) %>%
add_model(glm_spec) %>%
fit_resamples(
resamples = penguin_folds,
control = control_resamples(extract = extract_model, save_pred = TRUE)
)
既然您已经在重采样中使用了 extract_model
,它就会出现在您的结果中,并且您拥有每个折叠可用的模型。
glm_rs
#> # Resampling results
#> # 5-fold cross-validation using stratification
#> # A tibble: 5 x 6
#> splits id .metrics .notes .extracts .predictions
#> <list> <chr> <list> <list> <list> <list>
#> 1 <split [198/… Fold1 <tibble [2 × … <tibble [0 ×… <tibble [1 ×… <tibble [51 × …
#> 2 <split [199/… Fold2 <tibble [2 × … <tibble [0 ×… <tibble [1 ×… <tibble [50 × …
#> 3 <split [199/… Fold3 <tibble [2 × … <tibble [0 ×… <tibble [1 ×… <tibble [50 × …
#> 4 <split [200/… Fold4 <tibble [2 × … <tibble [0 ×… <tibble [1 ×… <tibble [49 × …
#> 5 <split [200/… Fold5 <tibble [2 × … <tibble [0 ×… <tibble [1 ×… <tibble [49 × …
glm_rs$.extracts[[1]]
#> # A tibble: 1 x 2
#> .extracts .config
#> <list> <chr>
#> 1 <glm> Preprocessor1_Model1
您可以使用 tidyr 和 broom 函数来获取系数,如果这正是您所需要的。
glm_rs %>%
dplyr::select(id, .extracts) %>%
unnest(cols = .extracts) %>%
mutate(tidied = map(.extracts, tidy)) %>%
unnest(tidied)
#> # A tibble: 30 x 8
#> id .extracts .config term estimate std.error statistic p.value
#> <chr> <list> <chr> <chr> <dbl> <dbl> <dbl> <dbl>
#> 1 Fold1 <glm> Preprocesso… (Interce… -7.44e+1 12.6 -5.89 3.75e-9
#> 2 Fold1 <glm> Preprocesso… speciesC… -6.59e+0 1.82 -3.61 3.03e-4
#> 3 Fold1 <glm> Preprocesso… speciesG… -7.49e+0 2.54 -2.95 3.18e-3
#> 4 Fold1 <glm> Preprocesso… bill_len… 5.56e-1 0.151 3.67 2.40e-4
#> 5 Fold1 <glm> Preprocesso… bill_dep… 1.72e+0 0.424 4.06 4.83e-5
#> 6 Fold1 <glm> Preprocesso… body_mas… 5.88e-3 0.00130 4.51 6.44e-6
#> 7 Fold2 <glm> Preprocesso… (Interce… -6.87e+1 11.3 -6.06 1.37e-9
#> 8 Fold2 <glm> Preprocesso… speciesC… -5.59e+0 1.75 -3.20 1.39e-3
#> 9 Fold2 <glm> Preprocesso… speciesG… -7.61e+0 2.80 -2.71 6.65e-3
#> 10 Fold2 <glm> Preprocesso… bill_len… 4.88e-1 0.145 3.36 7.88e-4
#> # … with 20 more rows
由 reprex package (v2.0.0) 于 2021 年 6 月 27 日创建