R data.table:计算分组频率

时间:2014-06-07 02:50:50

标签: r aggregate data.table histogram

我正在尝试向我的data.table添加列,这实际上为每个聚合的组附加了一个累积频率表。不幸的是,我目前的解决方案比我希望的慢十倍。

这是我正在使用的(为丑陋的单行道歉):

DT[, c("bin1","bin2","bin3","bin4") := as.list(cumsum(hist(colx,c(lbound,bound1,bound2, bound3,ubound),plot=FALSE)$counts)), by=category]

如果bin边界设置为0,25,50,75,100,我希望我的表格看起来像:

id category colx bin1 bin2 bin3 bin4
1  a        5    1    2    2    3
2  a        30   1    2    2    3
3  b        21   1    2    3    4
4  c        62   0    1    3    3
5  b        36   1    2    3    4
6  a        92   1    2    2    3
7  c        60   0    1    3    3
8  b        79   1    2    3    4
9  b        54   1    2    3    4
10 c        27   0    1    3    3

在实际的数据集中,我使用4个不同的列进行分组,并且有数百万行和唯一的组。当我尝试更简单的函数(例如sum)时,需要一段可接受的时间来进行计算。有没有办法显着加快计数过程?

1 个答案:

答案 0 :(得分:1)

好的,这是一种方式(这里我使用data.table v1.9.3)。如果您使用的是by=.EACHI版本,请移除<= 1.9.2

dt[, ival := findInterval(colx, seq(0, 100, by=25), rightmost.closed=TRUE)]
setkey(dt, category, ival)
ans <- dt[CJ(unique(category), unique(ival)), .N, allow.cartesian=TRUE, by=.EACHI]
ans[, N := cumsum(N), by="category"][, bin := "bin"]
ans <- dcast.data.table(ans, category ~ bin+ival, value.var="N")
ans <- dt[ans][, ival := NULL]

    id category colx bin_1 bin_2 bin_3 bin_4
 1:  1        a    5     1     2     2     3
 2:  2        a   30     1     2     2     3
 3:  6        a   92     1     2     2     3
 4:  3        b   21     1     2     3     4
 5:  5        b   36     1     2     3     4
 6:  9        b   54     1     2     3     4
 7:  8        b   79     1     2     3     4
 8: 10        c   27     0     1     3     3
 9:  4        c   62     0     1     3     3
10:  7        c   60     0     1     3     3

模拟大数据的基准:

我在这里生成一个包含2000万行的data.table和总共100万个包含2个分组列的组(而不是你在问题中陈述的4个)。

K = 1e3L
N = 20e6L
sim_data <- function(K, N) {
    set.seed(1L)
    ff <- function(K, N) sample(paste0("V", 1:K), N, TRUE)
    data.table(x=ff(K,N), y=ff(K,N), val=sample(1:100, N, TRUE))
}

dt <- sim_data(K, N)
method1 <- function(x) { 
    dt[, ival := findInterval(val, seq(0, 100, by=25), rightmost.closed=TRUE)]
    setkey(dt, x, y, ival)
    ans <- dt[CJ(unique(x), unique(y), unique(ival)), .N, allow.cartesian=TRUE, by=.EACHI]
    ans[, N := cumsum(N), by="x,y"][, bin := "bin"]
    ans <- dcast.data.table(ans, x+y ~ bin+ival, value.var="N")
    ans <- dt[ans][, ival := NULL]
}

system.time(ans1 <- method1(dt))
#   user  system elapsed 
# 13.148   2.778  16.209 

我希望这比原始解决方案更快,并且可以很好地扩展您的实际数据维度。


更新:以下是使用data.table's 滚动加入而非基座的 findInterval 的另一个版本。我们要稍微修改间隔,以便滚动连接找到正确的匹配。

dt <- sim_data(K, N)
method2 <- function(x) {
    ivals = seq(24L, 100L, by=25L)
    ivals[length(ivals)] = 100L
    setkey(dt, x,y,val)
    dt[, ival := seq_len(.N), by="x,y"]
    ans <- dt[CJ(unique(x), unique(y), ivals), roll=TRUE, mult="last"][is.na(ival), ival := 0L][, bin := "bin"]
    ans <- dcast.data.table(ans, x+y~bin+val, value.var="ival")
    dt[, ival := NULL]
    ans2 <- dt[ans]
}

system.time(ans2 <- method2(dt))
#   user  system elapsed 
# 12.538   2.649  16.079 

## check if both methods give identical results:

setkey(ans1, x,y,val)
setnames(ans2, copy(names(ans1)))
setkey(ans2, x,y,val)

identical(ans1, ans2) # [1] TRUE

编辑:有关OP为何非常耗时的一些解释:

我怀疑,这些解决方案与hist之间的运行时差异的一个重要原因是这里的答案都是矢量化的(完全用C语言编写并且可以直接在整个数据集上工作),其中as hist是一个S3方法(它会花费时间调度到.default方法并添加到它中,它用R编写。所以,基本上你执行的时间大约是一百万{{1} R中的一个函数,其他两个矢量化解决方案在C中调用一次(这里不需要为每个组调用)。

因为这是你问题中最复杂的部分,所以它显然会减慢速度。