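Using data.table to get a conditional sum by group with lapply

Date: 2018-07-20 13:46:50

Tags: r data.table lapply

I have a data.table in which each row is an event with a start date and an end date, but the number of days between each start and end varies. I am therefore trying to count how many other events had already ended by the time each event starts. I can do this with lapply, but when I try to combine it with data.table's by functionality, I don't get the expected output. Sample code below: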

library(data.table)

DT <- data.table(
  start = as.Date(c("2018-07-01","2018-07-03","2018-07-06","2018-07-08","2018-07-12","2018-07-15")),
  end = as.Date(c("2018-07-10","2018-07-04","2018-07-09","2018-07-20","2018-07-14","2018-07-27")),
  group_id = c("a", "a", "a", "b", "b", "b"))

# This produces the expected output (0,0,1,1,3,4):
lapply(DT$start, function(x) sum(x > DT$end))

# This also works using data.table:
DT[, count := lapply(DT$start, function(x) sum(x > DT$end))]

# However, I don't get the expected output (0,0,1,0,0,1) when I attempt to do this by group_id
DT[, count_by_group := lapply(DT$start, function(x) sum(x > DT$end)), by = group_id]

This gives the following output, where count_by_group is not the expected result:

        start        end group_id count count_by_group
1: 2018-07-01 2018-07-10        a     0              0
2: 2018-07-03 2018-07-04        a     0              0
3: 2018-07-06 2018-07-09        a     1              0
4: 2018-07-08 2018-07-20        b     1              0
5: 2018-07-12 2018-07-14        b     3              0
6: 2018-07-15 2018-07-27        b     4              0

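Can someone help me understand how by changes the behaviour? I have also tried different versions of the .SD feature, but could not get that to work either.

2 Answers:

Answer 0: (score: 2)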

unlist()

Wrapping the lapply() call in unlist() also works:

DT[, count_by_group := unlist(lapply(start, function(x) sum(x > end))), by = group_id]

Non-equi join

Alternatively, this can be solved by aggregating in a non-equi self join:

DT[, count_by_group := DT[DT, on = .(group_id, end < start), .N, by = .EACHI]$N]
DT
        start        end group_id count_by_group
1: 2018-07-01 2018-07-10        a              0
2: 2018-07-03 2018-07-04        a              0
3: 2018-07-06 2018-07-09        a              1
4: 2018-07-08 2018-07-20        b              0
5: 2018-07-12 2018-07-14        b              0
6: 2018-07-15 2018-07-27        b              1
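
As a sanity check, the grouped apply and the non-equi self join can be compared on the question's data. This is a minimal sketch: the names by_vapply and by_join are introduced here for illustration only, and the counts are returned as plain vectors rather than assigned with :=.

```r
library(data.table)

# The sample data from the question
DT <- data.table(
  start = as.Date(c("2018-07-01","2018-07-03","2018-07-06","2018-07-08","2018-07-12","2018-07-15")),
  end = as.Date(c("2018-07-10","2018-07-04","2018-07-09","2018-07-20","2018-07-14","2018-07-27")),
  group_id = c("a", "a", "a", "b", "b", "b"))

# Per group: for each start, count the ends in the same group that precede it
by_vapply <- DT[, .(n = vapply(start, function(x) sum(x > end), integer(1))),
                by = group_id]$n

# Same count via a non-equi self join: match rows of the same group whose
# end is before the current row's start, and count matches per row of i
by_join <- DT[DT, on = .(group_id, end < start), .N, by = .EACHI]$N

by_vapply                      # 0 0 1 0 0 1, the expected result from the question
all.equal(by_vapply, by_join)  # checks that both approaches agree
```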

Benchmark

For cases with a few hundred rows, the non-equi join is also the fastest method:

library(bench)
bm <- press(
  n_grp = c(2L, 5L, 10L),
  n_row = 10^(2:4),
  {
    set.seed(1L)
    DT = data.table(
      group_id = sample.int(n_grp, n_row, TRUE),
      start = as.Date("2018-07-01") + rpois(n_row, 20L))
    DT[, end := start + rpois(n_row, 10L)]
    setorder(DT, group_id, start, end)
    mark(
      unlist = copy(DT)[, count_by_group := unlist(lapply(start, function(x) sum(x > end))), by = group_id],
      sapply = copy(DT)[, count_by_group := sapply(start, function(x) sum(x > end)), by = group_id],
      vapply = copy(DT)[, count_by_group := vapply(start, function(x) sum(x > end), integer(1)), by = group_id],
      nej = copy(DT)[, count_by_group := DT[DT, on = .(group_id, end < start), .N, by = .EACHI]$N]
    )
  }
)
ggplot2::autoplot(bm)

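For 10000 rows, the non-equi join is about 10 times faster than the other methods.

Because := updates DT in place, copy() is used to create a fresh, unmodified copy of DT for each benchmark run.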
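
Why copy() matters can be seen in a small sketch. The table and the names alias and fresh are made up for this illustration and are not part of the benchmark:

```r
library(data.table)

DT <- data.table(x = 1:3)

alias <- DT            # a second name for the same table, not a copy
alias[, y := x * 2L]   # := adds the column in place
names(DT)              # DT now has the column y as well

fresh <- copy(DT)      # an independent copy
fresh[, z := 0L]       # only the copy gets the column z
names(DT)              # DT still has x and y only
```

Without copy(), the first benchmarked expression would add count_by_group to the shared DT, and the later expressions would no longer start from the same input.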

Answer 1: (score: 0)

DT[, count_by_group := vapply(start, function(x) sum(x > end), integer(1)), by = group_id]

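To reference start and end by group, we need to omit the DT$ prefix.
We use vapply() rather than lapply() because if the right-hand side of := is a list, it is interpreted as a list of columns (and since only one column is expected, only the first element, 0, is considered and recycled).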
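
The effect of the DT$ prefix can be illustrated with a short sketch on the question's data. The column names rows_in_group and rows_in_table and the variable res are chosen here for illustration:

```r
library(data.table)

DT <- data.table(
  start = as.Date(c("2018-07-01","2018-07-03","2018-07-06","2018-07-08","2018-07-12","2018-07-15")),
  end = as.Date(c("2018-07-10","2018-07-04","2018-07-09","2018-07-20","2018-07-14","2018-07-27")),
  group_id = c("a", "a", "a", "b", "b", "b"))

# Inside j with by, a bare column name is the current group's subset,
# while DT$start always fetches the full column of the whole table
res <- DT[, .(rows_in_group = length(start),
              rows_in_table = length(DT$start)),
          by = group_id]
res
# rows_in_group is 3 for each group, rows_in_table is 6 for each group
```

This is why the question's grouped call kept comparing every start in the table against every end in the table, regardless of group_id.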