我正在使用 Titanic 数据集尝试 parsnip 包。
library(titanic)
library(dplyr)
library(tidymodels)
library(rattle)
library(rpart.plot)
library(RColorBrewer)
# Convert the outcome (Survived) and the categorical predictors to factors
# before modelling.
train <- mutate(
  titanic_train,
  Survived = factor(Survived),
  Sex = factor(Sex),
  Embarked = factor(Embarked)
)

# Same recoding for the test data; only Sex and Embarked are converted here.
test <- mutate(
  titanic_test,
  Sex = factor(Sex),
  Embarked = factor(Embarked)
)
# Model specification: a classification decision tree using the rpart engine.
spec_obj <- set_engine(decision_tree(mode = "classification"), "rpart")
spec_obj

# Fit the specification to the training data.
fit_obj <- fit(
  spec_obj,
  Survived ~ Pclass + Sex + Age + SibSp + Parch + Fare + Embarked,
  data = train
)
fit_obj

# Plot the engine-level fit stored in the `fit` element of the parsnip object.
fancyRpartPlot(fit_obj$fit)

# Predict on the test data.
pred <- predict(fit_obj, new_data = test)
pred
假设我想在模型函数中添加一些参数。
# Change the tree settings on the existing specification, then refit and
# re-plot with the same formula and data as before.
spec_obj <- spec_obj %>%
  update(min_n = 50, cost_complexity = 0)

fit_obj <- fit(
  spec_obj,
  Survived ~ Pclass + Sex + Age + SibSp + Parch + Fare + Embarked,
  data = train
)
fit_obj

fancyRpartPlot(fit_obj$fit)
有什么方法可以避免在 fit() 函数中再次指定模型公式和数据集吗?
==============编辑===============
我发现可以将公式保存在变量中:
# Store the model formula once so that every fit() call can reuse it.
# A formula literal is used instead of as.formula() on a string: it yields
# the same formula object but is checked by the parser rather than at run time.
f <- Survived ~ Pclass + Sex + Age + SibSp + Parch + Fare + Embarked
# Refit using the formula stored in `f`.
fit_obj <- fit(spec_obj, f, data = train)
fit_obj
或者,有没有更好的方法?
答案 0 :(得分:0)
我认为最好的方法是创建一个小的包装函数,比如叫做 fit_titanic():
library(titanic)
library(dplyr)
#>
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#>
#> filter, lag
#> The following objects are masked from 'package:base':
#>
#> intersect, setdiff, setequal, union
library(tidymodels)
#> ── Attaching packages ────────────────────────────────────────────── tidymodels 0.1.0 ──
#> ✓ broom 0.5.5 ✓ recipes 0.1.10
#> ✓ dials 0.0.6 ✓ rsample 0.0.6
#> ✓ ggplot2 3.3.0 ✓ tibble 3.0.1
#> ✓ infer 0.5.1 ✓ tune 0.1.0
#> ✓ parsnip 0.1.0 ✓ workflows 0.1.1
#> ✓ purrr 0.3.4 ✓ yardstick 0.0.6
#> ── Conflicts ───────────────────────────────────────────────── tidymodels_conflicts() ──
#> x purrr::discard() masks scales::discard()
#> x dplyr::filter() masks stats::filter()
#> x dplyr::lag() masks stats::lag()
#> x ggplot2::margin() masks dials::margin()
#> x recipes::step() masks stats::step()
# Convert the outcome and the categorical predictors to factors.
train <- mutate(
  titanic_train,
  Survived = factor(Survived),
  Sex = factor(Sex),
  Embarked = factor(Embarked)
)

# Baseline specification: a classification decision tree on the rpart engine.
spec1 <- set_engine(decision_tree(mode = "classification"), "rpart")
spec1
#> Decision Tree Model Specification (classification)
#>
#> Computational engine: rpart
#' Fit a model specification to the Titanic data
#'
#' Thin wrapper around `fit()` that fixes the model formula, so that different
#' specifications can be fitted without repeating the formula and data set.
#'
#' @param spec A model specification, e.g. one built with `decision_tree()`
#'   and `set_engine()`.
#' @param data Data to fit on. Defaults to the `train` object defined in this
#'   script, so existing `fit_titanic(spec)` calls behave as before; pass
#'   another data frame to fit on different data.
#' @return Whatever `fit()` returns for `spec`.
fit_titanic <- function(spec, data = train) {
  fit(
    spec,
    Survived ~ Pclass + Sex + Age + SibSp + Parch + Fare + Embarked,
    data = data
  )
}

# Fit the baseline specification through the wrapper.
fit_titanic(spec1)
#> parsnip model object
#>
#> Fit time: 17ms
#> n= 891
#>
#> node), split, n, loss, yval, (yprob)
#> * denotes terminal node
#>
#> 1) root 891 342 0 (0.61616162 0.38383838)
#> 2) Sex=male 577 109 0 (0.81109185 0.18890815)
#> 4) Age>=6.5 553 93 0 (0.83182640 0.16817360) *
#> 5) Age< 6.5 24 8 1 (0.33333333 0.66666667)
#> 10) SibSp>=2.5 9 1 0 (0.88888889 0.11111111) *
#> 11) SibSp< 2.5 15 0 1 (0.00000000 1.00000000) *
#> 3) Sex=female 314 81 1 (0.25796178 0.74203822)
#> 6) Pclass>=2.5 144 72 0 (0.50000000 0.50000000)
#> 12) Fare>=23.35 27 3 0 (0.88888889 0.11111111) *
#> 13) Fare< 23.35 117 48 1 (0.41025641 0.58974359)
#> 26) Embarked=S 63 31 0 (0.50793651 0.49206349)
#> 52) Fare< 10.825 37 15 0 (0.59459459 0.40540541) *
#> 53) Fare>=10.825 26 10 1 (0.38461538 0.61538462)
#> 106) Fare>=17.6 10 3 0 (0.70000000 0.30000000) *
#> 107) Fare< 17.6 16 3 1 (0.18750000 0.81250000) *
#> 27) Embarked=C,Q 54 16 1 (0.29629630 0.70370370) *
#> 7) Pclass< 2.5 170 9 1 (0.05294118 0.94705882) *
# Derive a second specification with different tree settings and fit it
# through the same wrapper.
spec2 <- spec1 %>%
  update(min_n = 50, cost_complexity = 0)
fit_titanic(spec2)
#> parsnip model object
#>
#> Fit time: 10ms
#> n= 891
#>
#> node), split, n, loss, yval, (yprob)
#> * denotes terminal node
#>
#> 1) root 891 342 0 (0.61616162 0.38383838)
#> 2) Sex=male 577 109 0 (0.81109185 0.18890815)
#> 4) Age>=6.5 553 93 0 (0.83182640 0.16817360) *
#> 5) Age< 6.5 24 8 1 (0.33333333 0.66666667) *
#> 3) Sex=female 314 81 1 (0.25796178 0.74203822)
#> 6) Pclass>=2.5 144 72 0 (0.50000000 0.50000000)
#> 12) Fare>=23.35 27 3 0 (0.88888889 0.11111111) *
#> 13) Fare< 23.35 117 48 1 (0.41025641 0.58974359)
#> 26) Embarked=S 63 31 0 (0.50793651 0.49206349)
#> 52) Fare< 10.825 37 15 0 (0.59459459 0.40540541) *
#> 53) Fare>=10.825 26 10 1 (0.38461538 0.61538462) *
#> 27) Embarked=C,Q 54 16 1 (0.29629630 0.70370370) *
#> 7) Pclass< 2.5 170 9 1 (0.05294118 0.94705882) *
由reprex package(v0.3.0)于2020-04-30创建