Sum across rows by group (many columns at once)

Time: 2018-06-14 11:04:09

Tags: r data.table

I need to get the sum over a large selection of columns, by group. For example:

library(data.table)

set.seed(123)

DT = data.table(grp = c("A", "B", "C"),
                x1 = sample(1:10, 3),
                x2 = sample(1:10, 3),
                x3 = sample(1:10, 3),
                x4 = sample(1:10, 3))

> DT
   grp x1 x2 x3 x4
1:   A  3  9  6  5
2:   B  8 10  9  9
3:   C  4  1  5  4

Say I want to sum x2 and x3. I would normally do this:

> DT[, .(total = sum(x2, x3)), by=grp]
   grp total
1:   A    15
2:   B    19
3:   C     6

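However, if the range of columns is very large, say 100, how can I code this elegantly without spelling out every column by name?

What I tried (and what didn't work):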

my_cols <- paste0("x", 2:3)

DT[, .(total = sum(get(my_cols))), by=grp]
   grp total
1:   A     9
2:   B    10
3:   C     1

It seems to use only the first column (x2) and ignore the rest.

1 answer:

Answer 0: (score: 3)

I couldn't find an exact dupe (summing across rows by group), so here are 5 different possibilities I could think of.

The main thing to remember here is that within each group you are working with a data.table, so some functions won't work without unlist.
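To make that point concrete, here is a minimal sketch on a made-up toy table (the names `toy`, `info` and `flat` are only for illustration and are not part of the question's data). It shows that `.SD` inside each group is itself a data.table, and that `unlist` flattens it into a plain vector:

```r
library(data.table)

# Hypothetical toy data: group A has two rows, group B has one
toy <- data.table(grp = c("A", "A", "B"),
                  V2  = c(1, 2, 3),
                  V3  = c(10, 20, 30))

# Inside each group, .SD is a data.table holding the .SDcols columns
info <- toy[, .(cls = class(.SD)[1], n_col = ncol(.SD)),
            by = grp, .SDcols = c("V2", "V3")]

# unlist() flattens that per-group table into one plain vector
flat <- toy[, .(len = length(unlist(.SD))),
            by = grp, .SDcols = c("V2", "V3")]

info
flat
```

Group A has two rows and two selected columns, so its flattened vector has four elements; functions that expect a plain vector see all of them at once.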
## Create example data
library(data.table)
set.seed(123)
DT <- data.table(grp = c("A", "B", "C"),
                 matrix(sample(1:10, 30 * 4, replace = TRUE), ncol = 4))

my_cols <- paste0("V", 2:3)

## 1 - Returns NA if any value is NA. It also runs without `unlist`,
## but then it returns per-row sums rather than one total per group.
DT[, Reduce(`+`, unlist(.SD)), .SDcols = my_cols, by = grp]

## 2 - Convert to long format first and then aggregate
melt(DT, id.vars = "grp", measure.vars = my_cols)[, sum(value), by = grp]

## 3 - Using `base::sum` which can handle data.frames, 
## see `?S4groupGeneric` (a data.table is also a data.frame)
DT[, base::sum(.SD), .SDcols = my_cols, by = grp]

## 4 - This will use data.table's enhanced `gsum` function,
## but it can't handle data.frames/data.tables,
## hence it requires `unlist` first. It will be interesting to measure the tradeoff
DT[, sum(unlist(.SD)), .SDcols = my_cols, by = grp]

## 5 - This is a modification of your original attempt that both handles multiple columns
## (`mget` instead of `get`) and adds `unlist`
## (no point trying with `base::sum` instead, because it will also require `unlist`)
DT[, sum(unlist(mget(my_cols))), by = grp]

All of these return the same result

#    grp  V1
# 1:   A 115
# 2:   B 105
# 3:   C  96
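As a quick sanity check, the sketch below reruns three of the approaches on the small table printed in the question. The values are typed in by hand from that printout instead of being regenerated with `sample`, and the column name `total` and the `res_*` object names are only for illustration:

```r
library(data.table)

# Values copied by hand from the table printed in the question
DT_q <- data.table(grp = c("A", "B", "C"),
                   x1 = c(3, 8, 4),
                   x2 = c(9, 10, 1),
                   x3 = c(6, 9, 5),
                   x4 = c(5, 9, 4))
q_cols <- paste0("x", 2:3)

# Approach 3: base::sum applied to the per-group data.table
res_base <- DT_q[, .(total = base::sum(.SD)), .SDcols = q_cols, by = grp]

# Approach 4: flatten the per-group data.table first, then sum
res_unlist <- DT_q[, .(total = sum(unlist(.SD))), .SDcols = q_cols, by = grp]

# Approach 2: reshape to long format, then aggregate
res_melt <- melt(DT_q, id.vars = "grp", measure.vars = q_cols)[
  , .(total = sum(value)), by = grp]

res_base
#    grp total
# 1:   A    15
# 2:   B    19
# 3:   C     6
```

The totals match the `sum(x2, x3)` output shown in the question, and `res_unlist` and `res_melt` hold the same values.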

Some benchmarks

library(data.table)
library(microbenchmark)
library(stringi)

set.seed(123)
N <- 1e5
cols <- 50
DT <- data.table(grp = stri_rand_strings(N / 1e4, 2),
                 matrix(sample(1:10, N * cols, replace = TRUE), 
                        ncol = cols))
my_cols <- paste0("V", 1:20)


mbench <- microbenchmark(
  "Reduce/unlist: " = DT[, Reduce(`+`, unlist(.SD)), .SDcols = my_cols, by = grp],
  "melt: " = melt(DT, "grp", measure = my_cols)[, sum(value), by = grp], 
  "base::sum: " = DT[, base::sum(.SD), .SDcols = my_cols, by = grp],
  "gsum/unlist: " = DT[, sum(unlist(.SD)), .SDcols = my_cols, by = grp],
  "gsum/mget/unlist: " = DT[, sum(unlist(mget(my_cols))), by = grp]
)

# Unit: milliseconds
#               expr        min         lq       mean     median         uq        max neval cld
#    Reduce/unlist:  1968.93628 2185.45706 2332.66770 2301.10293 2440.43138 3161.15522   100   c
#             melt:    33.91844   58.18254   66.70419   64.52190   74.29494  132.62978   100 a  
#        base::sum:    18.00297   22.44860   27.21083   25.14174   29.20080   77.62018   100 a  
#      gsum/unlist:   780.53878  852.16508  929.65818  894.73892  968.28680 1430.91928   100  b 
# gsum/mget/unlist:   797.99854  876.09773  963.70562  928.27375 1003.04632 1578.76408   100  b 

library(ggplot2)
autoplot(mbench)
