我尝试使用update
中的tidymodels
函数覆盖默认调整值,但是这些值无法更新。
例如,在下面的代码中,我想将min_n
的范围从默认值2更改为40到30到50。但是,min_n
的值保持在2和40。
library(tidymodels)
#> -- Attaching packages --------------------------------------------------------------------------------------------------------------------------- tidymodels 0.1.0 --
#> v broom 0.7.0 v recipes 0.1.13
#> v dials 0.0.8 v rsample 0.0.7
#> v dplyr 1.0.0 v tibble 3.0.3
#> v ggplot2 3.3.2 v tune 0.1.1
#> v infer 0.5.2 v workflows 0.1.2
#> v parsnip 0.1.2 v yardstick 0.0.7
#> v purrr 0.3.4
#> -- Conflicts ------------------------------------------------------------------------------------------------------------------------------ tidymodels_conflicts() --
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter() masks stats::filter()
#> x dplyr::lag() masks stats::lag()
#> x recipes::step() masks stats::step()
rf <- decision_tree(cost_complexity = tune(), tree_depth = tune(), min_n = tune()) %>%
set_mode("classification") %>%
set_engine("rpart")
rf_wf <- workflow() %>%
add_model(rf) %>%
add_formula(class ~ .)
param <- rf %>% parameters()
param %>% update(min_n = min_n(range = c(30L, 50L)))
#> Collection of 3 parameters for tuning
#>
#> id parameter type object class
#> cost_complexity cost_complexity nparam[+]
#> tree_depth tree_depth nparam[+]
#> min_n min_n nparam[+]
rf_grid <- grid_regular(param, levels = 2)
rf_grid
#> # A tibble: 8 x 3
#> cost_complexity tree_depth min_n
#> <dbl> <int> <int>
#> 1 0.0000000001 1 2
#> 2 0.1 1 2
#> 3 0.0000000001 15 2
#> 4 0.1 15 2
#> 5 0.0000000001 1 40
#> 6 0.1 1 40
#> 7 0.0000000001 15 40
#> 8 0.1 15 40
由reprex package(v0.3.0)于2020-07-26创建
答案 0 :(得分:1)
update
方法返回一个新的参数对象-它不会更新就位传递给您的值。您需要做
newparam <- param %>% update(min_n = min_n(range = c(30L, 50L)))
grid_regular(newparam, levels = 2)
# cost_complexity tree_depth min_n
# <dbl> <int> <int>
# 1 0.0000000001 1 30
# 2 0.1 1 30
# 3 0.0000000001 15 30
# 4 0.1 15 30
# 5 0.0000000001 1 50
# 6 0.1 1 50
# 7 0.0000000001 15 50
# 8 0.1 15 50