我有一个包含多个类的大型数据集。我的目标是将模型拟合到每个类中,然后预测结果并将其可视化为方面中的每个类。
对于可重现的示例,我使用mtcars
创建了一些基本的东西。这适用于每个类的简单回归模型。
mtcars = data.table(mtcars)
model = mtcars[, list(fit = list(lm(mpg~disp+hp+wt))), keyby = cyl]
setkey(mtcars, cyl)
mtcars[model, pred := predict(i.fit[[1]], .SD), by = .EACHI]
ggplot(data = mtcars, aes(x = mpg, y = pred)) + geom_line() + facet_wrap(~cyl)
但是,我想尝试类似下面的内容,但这还不行。这个尝试是一个公式列表,但我也希望向每个数据子集发送不同的模型(一些glms,一些树)。
mtcars = data.table(mtcars)
factors = list(c("disp","wt"), c("disp"), c("hp"))
form = lapply(factors, function(x) as.formula(paste("mpg~",paste(x,collapse="+"))))
model = mtcars[, list(fit = list(lm(form))), keyby = cyl]
setkey(mtcars, cyl)
mtcars[model, pred := predict(i.fit[[1]], .SD), by = .EACHI]
ggplot(data = mtcars, aes(x = mpg, y = pred)) + geom_line() + facet_wrap(~cyl)
答案 0 :(得分:4)
以下是我们为每个模型设置predict
作为未评估列表的方法,在data.table
对象中评估它们,gather
输出,并将其传递给ggplot
:
models = quote(list(
predict(lm(form[[1]], .SD)),
predict(lm(form[[2]], .SD)),
predict(lm(form[[3]], .SD))))
d <- mtcars
d[, c("est1", "est2", "est3") := eval(models), by = cyl]
d <- tidyr::gather(d, key = model, value = pred, est1:est3)
library(ggplot2)
ggplot(d, aes(x = mpg, y = pred)) + geom_line() + facet_grid(cyl ~ model)
输出:
答案 1 :(得分:3)
lm()
也接受公式作为字符向量。因此,我只需将form
创建为:
form = lapply(factors, function(x) paste("mpg~", paste(x, collapse="+")))
并且,您需要提供正确的数据(使用内置的特殊符号.SD
对应每个组):
model = mtcars[, list(fit=lapply(form, lm, data=.SD)), keyby=cyl]
对于每个cyl
,form
循环播放,相应的公式每次作为lm
的第一个参数传递给data = .SD
,其中.SD
代表数据子集,它本身就是一个data.table。您可以从vignettes了解更多相关信息。
如果您还想在结果中使用公式,那么:
chform = unlist(form)
model = mtcars[, list(form=chform, fit=lapply(form, lm, data=.SD)), keyby = cyl]
HTH
PS:如果您打算使用data.tables在update()
内使用[...]
,请阅读this post。
答案 2 :(得分:1)
我此刻实际上正在做这个,所以完美的时机。这将是一个很好的回答,但我真的很喜欢它的工作方式。
purrr
有一些非常方便的map
函数,与tibble
中的列表列结合使用时非常流畅。使用你的定义(我不试图优化它)
library(data.table)
mtcars = data.table(mtcars)
factors = list(c("disp","wt"), c("disp"), c("hp"))
form = lapply(factors, function(x) as.formula(paste("mpg~",paste(x,collapse="+"))))
提供了一个函数列表,这些函数可以传递给purrr::invoke_map
,它将一个参数列表(你有)应用到一个函数列表中(在你的例子中只是lm
,但是我使用可选参数(在您的示例中为mtcars
),怀疑可以扩展到其他人。使用tibble
,这些存储为一个整洁的data.frame
- esque list
,否则它们将作为lm
个对象返回
library(tibble)
library(purrr)
models <- tibble(fit = invoke_map(lm, form, data = mtcars))
models
#> # A tibble: 3 x 1
#> fit
#> <list>
#> 1 <S3: lm>
#> 2 <S3: lm>
#> 3 <S3: lm>
当您想要对所有这些元素执行某些操作时,超级有用的部分会出现,例如,提取拟合系数:
map(models$fit, coefficients)
#> [[1]]
#> (Intercept) disp wt
#> 34.96055404 -0.01772474 -3.35082533
#>
#> [[2]]
#> (Intercept) disp
#> 29.59985476 -0.04121512
#>
#> [[3]]
#> (Intercept) hp
#> 30.09886054 -0.06822828
或重新检查使用的公式
map(models$fit, formula)
#> [[1]]
#> mpg ~ disp + wt
#> <environment: 0x0000000017ee73a8>
#>
#> [[2]]
#> mpg ~ disp
#> <environment: 0x0000000018392c58>
#>
#> [[3]]
#> mpg ~ hp
#> <environment: 0x0000000018471d18>
此外,如果您想从模型中添加一些预测,可以使用broom::augment
library(broom)
models_with_predicts <- models %>% mutate(predict = map(fit, augment))
models_with_predicts
#> # A tibble: 3 x 2
#> fit predict
#> <list> <list>
#> 1 <S3: lm> <data.frame [32 x 10]>
#> 2 <S3: lm> <data.frame [32 x 9]>
#> 3 <S3: lm> <data.frame [32 x 9]>
您可以通过unnest()
返回到数据级别(带有预测),但这将合并您的所有数据(添加分组级别以保持拟合分开)
library(tidyr)
unnest(models_with_predicts, predict)
#> # A tibble: 96 x 11
#> mpg disp wt .fitted .se.fit .resid .hat .sigma .cooksd .std.resid hp
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 21.0 160.0 2.620 23.34543 0.6075520 -2.3454326 0.04339369 2.933379 0.010222201 -0.8222164 NA
#> 2 21.0 160.0 2.875 22.49097 0.6221836 -1.4909721 0.04550894 2.954135 0.004351414 -0.5232550 NA
#> 3 22.8 108.0 2.320 25.27237 0.7326015 -2.4723669 0.06309504 2.928665 0.017217431 -0.8757799 NA
#> 4 21.4 258.0 3.215 19.61467 0.5743205 1.7853334 0.03877647 2.948162 0.005241995 0.6243627 NA
#> 5 18.7 360.0 3.440 17.05281 1.0943208 1.6471930 0.14078260 2.949120 0.020275438 0.6092882 NA
#> 6 18.1 225.0 3.460 19.37863 0.6122393 -1.2786309 0.04406584 2.957872 0.003089406 -0.4483953 NA
#> 7 14.3 360.0 3.570 16.61720 0.9897465 -2.3171997 0.11516157 2.931444 0.030948880 -0.8446199 NA
#> 8 24.4 146.7 3.190 21.67120 0.9053245 2.7287988 0.09635365 2.918183 0.034431234 0.9842424 NA
#> 9 22.8 140.8 3.150 21.90981 0.9165259 0.8901898 0.09875274 2.962885 0.003775416 0.3215070 NA
#> 10 19.2 167.6 3.440 20.46305 0.9678618 -1.2630477 0.11012510 2.957375 0.008693734 -0.4590766 NA
#> # ... with 86 more rows