How can I create a column of formulas (e.g. y ~ x, y ~ log(x), ...) from a nested data frame of models?
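In the attempt below, the "model" column holds the model with the highest R-squared. The point of adding a column of model formulas is to identify which model was chosen in each row.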
library(tidyverse)
library(broom)

# nest the gapminder data by country, renaming year -> x and lifeExp -> y
df <- gapminder::gapminder %>%
  select(country, x = year, y = lifeExp) %>%
  group_by(country) %>%
  nest()

# R-squared of a fitted lm
rsq_f <- function(model) {
  summary(model)$r.squared
}

# fit four candidate models and return the one with the highest R-squared
best_model <- function(df) {
  models <- list(
    lm(formula = y ~ x, data = df),
    lm(formula = y ~ log(x), data = df),
    lm(formula = log(y) ~ x, data = df),
    lm(formula = log(y) ~ log(x), data = df)
  )
  R_squared <- map_dbl(models, rsq_f)
  best_model_num <- which.max(R_squared)
  models[[best_model_num]]
}

models <- df %>%
  mutate(
    model = map(data, best_model),
    rsq = map(model, broom::glance) %>% map_dbl("r.squared"),
    fun_call = map(model, formula)
  )
The output is:
> models
# A tibble: 142 x 5
country data model rsq fun_call
<fct> <list> <list> <dbl> <list>
1 Afghanistan <tibble [12 x 2]> <S3: lm> 0.949 <S3: formula>
2 Albania <tibble [12 x 2]> <S3: lm> 0.912 <S3: formula>
3 Algeria <tibble [12 x 2]> <S3: lm> 0.986 <S3: formula>
4 Angola <tibble [12 x 2]> <S3: lm> 0.890 <S3: formula>
5 Argentina <tibble [12 x 2]> <S3: lm> 0.996 <S3: formula>
6 Australia <tibble [12 x 2]> <S3: lm> 0.983 <S3: formula>
7 Austria <tibble [12 x 2]> <S3: lm> 0.994 <S3: formula>
8 Bahrain <tibble [12 x 2]> <S3: lm> 0.968 <S3: formula>
9 Bangladesh <tibble [12 x 2]> <S3: lm> 0.997 <S3: formula>
10 Belgium <tibble [12 x 2]> <S3: lm> 0.995 <S3: formula>
# ... with 132 more rows
I want to actually see the formula each model uses, instead of <S3: formula>.
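For comparison, here is a minimal base-R sketch of the same idea that needs no tidyverse packages: fit the four candidate formulas, keep the one with the highest R-squared, and return its formula as text. `formula()` recovers the formula from an `lm` fit and `deparse()` turns it into a character string. The helper name `best_formula()` and the toy data are assumptions made for illustration only.

```r
# Hypothetical helper: fit the four candidate models used in the question
# and return the formula of the one with the highest R-squared as a string.
best_formula <- function(df) {
  candidates <- list(
    y ~ x,
    y ~ log(x),
    log(y) ~ x,
    log(y) ~ log(x)
  )
  fits <- lapply(candidates, function(f) lm(f, data = df))
  r_squared <- vapply(fits, function(m) summary(m)$r.squared, numeric(1))
  best <- fits[[which.max(r_squared)]]
  # deparse() may split long formulas into several strings, so collapse them
  paste(deparse(formula(best)), collapse = " ")
}

# Made-up toy data: y grows with log(x), plus a small deterministic wiggle
toy <- data.frame(x = 1:20)
toy$y <- 3 + 2 * log(toy$x) + 0.01 * sin(toy$x)
best_formula(toy)  # "y ~ log(x)"
```

Returning the string directly avoids the list-column of formula objects altogether, at the cost of no longer keeping the fitted model.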
Answer 0 (score: 4)
Following RLave's comment, the answer is simply to add as.character():
models <- df %>%
  mutate(
    model = map(data, best_model),
    rsq = map(model, broom::glance) %>% map_dbl("r.squared"),
    fun_call = map(model, formula) %>% as.character()
  )
This gives:
# A tibble: 142 x 5
country data model rsq fun_call
<fct> <list> <list> <dbl> <chr>
1 Afghanistan <tibble [12 x 2]> <S3: lm> 0.949 y ~ log(x)
2 Albania <tibble [12 x 2]> <S3: lm> 0.912 y ~ log(x)
3 Algeria <tibble [12 x 2]> <S3: lm> 0.986 y ~ log(x)
4 Angola <tibble [12 x 2]> <S3: lm> 0.890 y ~ log(x)
5 Argentina <tibble [12 x 2]> <S3: lm> 0.996 y ~ x
6 Australia <tibble [12 x 2]> <S3: lm> 0.983 log(y) ~ x
7 Austria <tibble [12 x 2]> <S3: lm> 0.994 log(y) ~ x
8 Bahrain <tibble [12 x 2]> <S3: lm> 0.968 y ~ log(x)
9 Bangladesh <tibble [12 x 2]> <S3: lm> 0.997 log(y) ~ x
10 Belgium <tibble [12 x 2]> <S3: lm> 0.995 log(y) ~ x
# ... with 132 more rows
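`as.character()` appears to work here because, applied to a list, it deparses each element into a string. If you prefer a conversion that is guaranteed to yield exactly one string per model (the same guarantee `purrr::map_chr()` gives), a minimal base-R sketch is to deparse each model's formula with `vapply()`. The helper `formula_text()` and the two toy fits are hypothetical stand-ins for the `model` list-column.

```r
# Hypothetical helper: one formula string per fitted model
formula_text <- function(model) {
  paste(deparse(formula(model)), collapse = " ")
}

# Made-up toy data and two toy fits standing in for the `model` column
d <- data.frame(x = 1:10)
d$y <- exp(0.1 * d$x) + 0.05 * cos(d$x)
models <- list(lm(y ~ x, data = d), lm(log(y) ~ x, data = d))

# vapply() errors unless every result is a single character string
fun_call <- vapply(models, formula_text, character(1))
fun_call  # "y ~ x" "log(y) ~ x"
```

In the pipeline above, the same helper could be used as `fun_call = map_chr(model, formula_text)`.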
Answer 1 (score: 0)
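To make myself clearer I'll give an example as an answer. If I understand correctly, you are trying to add a column containing the formula as a string, e.g. "y ~ x".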
Suppose we have a simple lm:
x <- c(4.17, 5.58, 5.18, 6.11, 4.50, 4.61, 5.17, 4.53, 5.33, 5.14)
y <- c(4.81, 4.17, 4.41, 3.59, 5.87, 3.83, 6.03, 4.89, 4.32, 4.69)
my_lm <- lm(y ~ x)
Looking at the terms, you already have the formula, just not in the right order:
as.character(my_lm[["terms"]])
# [1] "~" "y" "x"
You only need to reorder the pieces so that "~" sits between the response and the predictors:
terms_chr <- as.character(my_lm$terms)
paste(terms_chr[2], terms_chr[1], terms_chr[-(1:2)])
# [1] "y ~ x"
The resulting string can then be assigned to a column with mutate.
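The same rearrangement can be wrapped in a small function and applied to any fitted model. This sketch assumes a two-sided formula: `as.character()` on such a formula returns `c("~", lhs, rhs)`, so a one-sided formula like `~ x` would need different handling. The name `terms_to_string()` is hypothetical.

```r
# Hypothetical helper: rebuild "lhs ~ rhs" from as.character() of a
# two-sided model formula, which comes back as c("~", lhs, rhs)
terms_to_string <- function(fit) {
  parts <- as.character(formula(fit))
  stopifnot(length(parts) == 3)  # guard: two-sided formulas only
  paste(parts[2], parts[1], parts[3])
}

x <- c(4.17, 5.58, 5.18, 6.11, 4.50, 4.61, 5.17, 4.53, 5.33, 5.14)
y <- c(4.81, 4.17, 4.41, 3.59, 5.87, 3.83, 6.03, 4.89, 4.32, 4.69)

terms_to_string(lm(y ~ x))                # "y ~ x"
terms_to_string(lm(log(y) ~ log(x)))      # "log(y) ~ log(x)"
```

Inside the question's pipeline this could be applied per row with `map_chr(model, terms_to_string)`.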
Answer 2 (score: 0)
Since you have already solved the problem, I just want to point out how easy grouped analyses are with the groupedstats package:
# loading needed libraries
library(tidyverse)

# creating `glance` summaries, one row per country and formula;
# `.id = "formula"` stores each formula's position (1-4) in the list
(results_df <- purrr::pmap_dfr(
  .l = list(
    data = list(gapminder::gapminder),
    grouping.vars = alist(country),
    formula = list(
      lifeExp ~ year,           # formula 1
      lifeExp ~ log(year),      # formula 2
      log(lifeExp) ~ year,      # formula 3
      log(lifeExp) ~ log(year)  # formula 4
    ),
    output = list("glance")
  ),
  .f = groupedstats::grouped_lm,
  .id = "formula"
))
#> # A tibble: 568 x 14
#> formula country r.squared adj.r.squared sigma statistic df logLik
#> <chr> <fct> <dbl> <dbl> <dbl> <dbl> <int> <dbl>
#> 1 1 Afghan~ 0.948 0.942 1.22 181. 2 -18.3
#> 2 1 Albania 0.911 0.902 1.98 102. 2 -24.1
#> 3 1 Algeria 0.985 0.984 1.32 662. 2 -19.3
#> 4 1 Angola 0.888 0.877 1.41 79.1 2 -20.0
#> 5 1 Argent~ 0.996 0.995 0.292 2246. 2 -1.17
#> 6 1 Austra~ 0.980 0.978 0.621 481. 2 -10.2
#> 7 1 Austria 0.992 0.991 0.407 1261. 2 -5.16
#> 8 1 Bahrain 0.967 0.963 1.64 291. 2 -21.9
#> 9 1 Bangla~ 0.989 0.988 0.977 930. 2 -15.7
#> 10 1 Belgium 0.995 0.994 0.293 1822. 2 -1.20
#> # ... with 558 more rows, and 6 more variables: AIC <dbl>, BIC <dbl>,
#> # deviance <dbl>, df.residual <int>, p.value <dbl>, significance <chr>
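Note that the `formula` column in this output holds each formula's position in the input list (1 to 4), not the formula itself. One way to recover readable labels, sketched here in base R, is to deparse the same list of formulas and index it by that position. The vector `ids` is a made-up stand-in for `results_df$formula`.

```r
# The candidate formulas, in the same order as in the pmap_dfr() call
formulas <- list(
  lifeExp ~ year,
  lifeExp ~ log(year),
  log(lifeExp) ~ year,
  log(lifeExp) ~ log(year)
)
# One text label per formula, e.g. "lifeExp ~ log(year)"
labels <- vapply(formulas, function(f) paste(deparse(f), collapse = " "),
                 character(1))

# Made-up stand-in for results_df$formula (a character column "1".."4")
ids <- c("1", "2", "2", "4")
labels[as.integer(ids)]
```

With dplyr, the equivalent step would be `mutate(results_df, formula = labels[as.integer(formula)])`.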
# models with maximum R-squared values
results_df %>%
  dplyr::group_by(.data = ., country) %>%
  dplyr::filter(.data = ., r.squared == max(r.squared))
#> # A tibble: 142 x 14
#> # Groups: country [142]
#> formula country r.squared adj.r.squared sigma statistic df logLik
#> <chr> <fct> <dbl> <dbl> <dbl> <dbl> <int> <dbl>
#> 1 1 Argent~ 0.996 0.995 0.292 2246. 2 -1.17
#> 2 1 Cambod~ 0.639 0.603 5.63 17.7 2 -36.7
#> 3 1 Ireland 0.984 0.983 0.478 621. 2 -7.07
#> 4 1 Madaga~ 0.995 0.994 0.560 1860. 2 -8.97
#> 5 1 Maurit~ 0.998 0.997 0.408 4290. 2 -5.16
#> 6 1 Switze~ 0.997 0.997 0.215 3823. 2 2.52
#> 7 1 Vietnam 0.989 0.988 1.31 934. 2 -19.2
#> 8 1 Yemen,~ 0.981 0.979 1.59 521. 2 -21.5
#> 9 2 Afghan~ 0.949 0.944 1.20 187. 2 -18.2
#> 10 2 Albania 0.912 0.903 1.96 104. 2 -24.0
#> # ... with 132 more rows, and 6 more variables: AIC <dbl>, BIC <dbl>,
#> # deviance <dbl>, df.residual <int>, p.value <dbl>, significance <chr>
Created on 2018-08-22 by the reprex package (v0.2.0.9000).
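One caveat about `filter(r.squared == max(r.squared))`: if two formulas tie on R-squared within a country, both rows are kept. If exactly one row per country is required, `which.max()` returns only the first maximum. A base-R sketch with made-up toy values (not gapminder results):

```r
# Made-up toy values: country "B" has a tie between formulas 1 and 3
toy_rsq <- data.frame(
  country   = c("A", "A", "B", "B", "B"),
  formula   = c("1", "2", "1", "2", "3"),
  r.squared = c(0.90, 0.95, 0.80, 0.70, 0.80)
)

# filter()-style selection keeps every row equal to the group maximum
group_max <- ave(toy_rsq$r.squared, toy_rsq$country, FUN = max)
keep_ties <- toy_rsq[toy_rsq$r.squared == group_max, ]

# which.max() keeps only the first maximum within each country
one_per_group <- do.call(rbind, lapply(split(toy_rsq, toy_rsq$country),
                                       function(g) g[which.max(g$r.squared), ]))

nrow(keep_ties)      # 3 rows: A/2, B/1, B/3
nrow(one_per_group)  # 2 rows: A/2, B/1
```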