按组快速线性回归

时间:2015-04-22 16:47:06

标签: r dplyr lm

我有 500K 用户,我需要为每个用户计算线性回归(带截距)

每个用户大约有30条记录。

我尝试使用dplyrlm,这太慢了。 用户大约2秒。

  df%>%                       
      group_by(user_id, add =  FALSE) %>%
      do(lm = lm(Y ~ x, data = .)) %>%
      mutate(lm_b0 = summary(lm)$coeff[1],
             lm_b1 = summary(lm)$coeff[2]) %>%
      select(user_id, lm_b0, lm_b1) %>%
      ungroup()
    )

我尝试使用已知速度更快的lm.fit,但它似乎与dplyr不兼容。

是否有快速的方法按组进行线性回归?

4 个答案:

答案 0 :(得分:11)

如果您想要的只是系数,我只会使用user_id作为回归中的一个因素。使用@ miles2know的模拟数据代码(虽然重命名,因为共享该名称的exp()以外的对象看起来很奇怪)

dat <- data.frame(id = rep(c("a","b","c"), each = 20),
                  x = rnorm(60,5,1.5),
                  y = rnorm(60,2,.2))

mod = lm(y ~ x:id + id + 0, data = dat)

我们不适用全局拦截(+ 0),因此每个ID的截距是id系数,而x本身没有,因此x:id相互作用是每个id的斜率:

coef(mod)
#      ida      idb      idc    x:ida    x:idb    x:idc 
# 1.779686 1.893582 1.946069 0.039625 0.033318 0.000353 

因此,对于aidida系数(1.78)是截距,x:ida系数(0.0396)是斜率。

我会将这些系数的收集留给您的数据框的相应列...

此解决方案应该非常快,因为您不必处理数据帧的子集。使用fastLm等可能会加快速度。

关于可伸缩性的注意事项:

我在@nrussell的模拟全尺寸数据上尝试了这个,并遇到了内存分配问题。根据您拥有的内存量,它可能无法一次性运行,但您可以在批量用户ID中执行此操作。他的答案和我的答案的某些组合可能是最快的整体 - 或者nrussell可能会更快 - 将用户id因子扩展为数千个虚拟变量可能不具备计算效率,因为我一直等待的不仅仅是现在几分钟就可以运行5000个用户ID。

答案 1 :(得分:8)

<强>更新 正如Dirk所指出的那样,通过直接指定xY而不是使用基于公式的fastLm接口,可以大大改善我的原始方法,这会导致(相当重要)处理开销。为了比较,使用原始的完整大小数据集,

R> system.time({
  dt[,c("lm_b0", "lm_b1") := as.list(
    unname(fastLm(x, Y)$coefficients))
    ,by = "user_id"]
})
#  user  system elapsed 
#55.364   0.014  55.401 
##
R> system.time({
  dt[,c("lm_b0","lm_b1") := as.list(
    unname(fastLm(Y ~ x, data=.SD)$coefficients))
    ,by = "user_id"]
})
#   user  system elapsed 
#356.604   0.047 356.820

这个简单的变化产生了大约 6.5倍的加速

[原创方法]

可能还有一些改进空间,但是在Linux VM(2.6 GHz处理器)上运行64位R后,大约需要25分钟:

library(data.table)
library(RcppArmadillo)
##
dt[
  ,c("lm_b0","lm_b1") := as.list(
    unname(fastLm(Y ~ x, data=.SD)$coefficients)),
  by=user_id]
##
R> dt[c(1:2, 31:32, 61:62),]
   user_id   x         Y     lm_b0    lm_b1
1:       1 1.0 1674.8316 -202.0066 744.6252
2:       1 1.5  369.8608 -202.0066 744.6252
3:       2 1.0  463.7460 -144.2961 374.1995
4:       2 1.5  412.7422 -144.2961 374.1995
5:       3 1.0  513.0996  217.6442 261.0022
6:       3 1.5 1140.2766  217.6442 261.0022

数据:

dt <- data.table(
  user_id = rep(1:500000,each=30))
##
dt[, x := seq(1, by=.5, length.out=30), by = user_id]
dt[, Y := 1000*runif(1)*x, by = user_id]
dt[, Y := Y + rnorm(
  30, 
  mean = sample(c(-.05,0,0.5)*mean(Y),1), 
  sd = mean(Y)*.25), 
  by = user_id]

答案 2 :(得分:6)

你可以尝试使用像这样的data.table。我刚刚创建了一些玩具数据,但我想想data.table会有所改进。这很快。但这是一个相当大的数据集,因此可能会在较小的样本上对此方法进行基准测试,以确定速度是否更好。祝好运。


    library(data.table)

    exp <- data.table(id = rep(c("a","b","c"), each = 20), x = rnorm(60,5,1.5), y = rnorm(60,2,.2))
    # edit: it might also help to set a key on id with such a large data-set
    # with the toy example it would make no diff of course
    exp <- setkey(exp,id)
    # the nuts and bolts of the data.table part of the answer
    result <- exp[, as.list(coef(lm(y ~ x))), by=id]
    result
       id (Intercept)            x
    1:  a    2.013548 -0.008175644
    2:  b    2.084167 -0.010023549
    3:  c    1.907410  0.015823088

答案 3 :(得分:1)

使用Rfast的示例。

假设单个响应和500K预测变量。

system.time( Rfast::mvbetas(x,y) )  ## 0.60 seconds

假定500K响应变量和单个预测变量。

public List<object> Foo(IEnumerable<object> objects)
{
    object firstObject;
    if (objects == null || !TryPeek(ref objects, out firstObject))
        throw new ArgumentException();

    var list = DoSomeThing(firstObject);
    var secondList = DoSomeThingElse(objects);
    list.AddRange(secondList);

    return list;
}

public static bool TryPeek<T>(ref IEnumerable<T> source, out T first)
{
    if (source == null)
        throw new ArgumentNullException(nameof(source));

    IEnumerator<T> enumerator = source.GetEnumerator();
    if (!enumerator.MoveNext())
    {
        first = default(T);
        source = Enumerable.Empty<T>();
        return false;
    }

    first = enumerator.Current;
    T firstElement = first;
    source = Iterate();
    return true;

    IEnumerable<T> Iterate()
    {
        yield return firstElement;
        using (enumerator)
        {
            while (enumerator.MoveNext())
            {
                yield return enumerator.Current;
            }
        }
    }
}

注意:以上时间在不久的将来会减少。