如何评估data.table中具有不同条件的列

时间:2014-07-10 03:09:34

标签: r data.table

鉴于data.table如下:

library(data.table)
set.seed(100)
dt <- data.table(a=c(1:3, 1), b = c(1,0,1, 3), c = c(1,2,1,3), x = rnorm(4), y = rnorm(4), d = c(4, 6, 6, 7)) 

dt返回,

   a b c           x          y d
1: 1 1 1 -0.50219235  0.1169713 4
2: 2 0 2  0.13153117  0.3186301 6
3: 3 1 1 -0.07891709 -0.5817907 6
4: 1 3 3  0.88678481  0.7145327 7

列中的任何数字&#34; a&#34;,&#34; b&#34;和&#34; c&#34;等于3的将为TRUE

此外,列中的任何数字&#34; d&#34;等于6的将为TRUE

如何使用列的名称评估dt内部(&#34; a&#34;,&#34; b&#34;,&#34; c&#34;,和&#34; d&#34)

这样我的回报就是:

       a     b     c           x          y     d
1: FALSE FALSE FALSE -0.50219235  0.1169713 FALSE
2: FALSE FALSE FALSE  0.13153117  0.3186301  TRUE
3:  TRUE FALSE FALSE -0.07891709 -0.5817907  TRUE
4: FALSE  TRUE  TRUE  0.88678481  0.7145327 FALSE

谢谢

4 个答案:

答案 0 :(得分:9)

我提出的方法如下:

dt[, c("a", "b", "c") := lapply(.SD, `==`, 3), 
   .SDcols = c("a", "b", "c")][, d := (d == 6)][]
#        a     b     c           x          y     d
# 1: FALSE FALSE FALSE -0.50219235  0.1169713 FALSE
# 2: FALSE FALSE FALSE  0.13153117  0.3186301  TRUE
# 3:  TRUE FALSE FALSE -0.07891709 -0.5817907  TRUE
# 4: FALSE  TRUE  TRUE  0.88678481  0.7145327 FALSE

在可读性方面它并没有赢得任何分数,但在性能方面似乎没有问题。

这里有一些要测试的样本数据:

library(data.table)
set.seed(100)
Nrow = 3000000
dt <- data.table(a = sample(10, Nrow, TRUE), 
                 b = sample(10, Nrow, TRUE), 
                 c = sample(10, Nrow, TRUE), 
                 x = rnorm(Nrow), 
                 y = rnorm(Nrow),
                 d = sample(10, Nrow, TRUE)) 

......一些要测试的功能......

fun1 <- function(indt) {
  indt[, c("a", "b", "c") := lapply(.SD, `==`, 3), 
     .SDcols = c("a", "b", "c")][, d := (d == 6)][]
}

fun2 <- function(indt) {
  for (i in c("a","b","c")) indt[, (i):=get(i)==3]
  for (i in c("d"))         indt[, (i):=get(i)==6]
  indt
}

fun3 <- function(indt) {
  f <- function(col,x) indt[,(col):=(.SD==x),.SDcols=col]
  lapply(list("a","b","c"), f, 3)
  lapply(list("d"), f, 6)
  indt
}

......和一些时间......

microbenchmark(fun1(copy(dt)), fun2(copy(dt)), fun3(copy(dt)), times = 10)
# Unit: milliseconds
#            expr      min        lq    median        uq       max neval
#  fun1(copy(dt)) 518.6034  535.0848  550.3178  643.2968  695.5819    10
#  fun2(copy(dt)) 830.5808 1037.8790 1172.6684 1272.6236 1608.9753    10
#  fun3(copy(dt)) 922.6474 1029.8510 1097.7520 1145.1848 1340.2009    10

identical(fun1(copy(dt)), fun2(copy(dt)))
# [1] TRUE
identical(fun2(copy(dt)), fun3(copy(dt)))
# [1] TRUE

在这个范围内,我会选择最适合你的东西(除非那些毫秒真的很重要),但如果你的数据较大,你可能想要尝试更多的不同选项。


Matt的补充

同意。要跟进评论,请点击这里fun4,但它只是这个尺寸上最快的smidgen(3e6行,90MB)

fun4 <- function(indt) {
  for (i in c("a","b","c")) set(indt,NULL,i,indt[[i]]==3)
  for (i in c("d"))         set(indt,NULL,i,indt[[i]]==6)
  indt
}

microbenchmark(copy(dt), fun1(copy(dt)), fun2(copy(dt)), fun3(copy(dt)), 
               fun4(copy(dt)), times = 10)
# Unit: milliseconds
#            expr        min         lq     median         uq       max neval
#        copy(dt)   64.13398   65.94222   68.32217   82.39942  110.3293    10
#  fun1(copy(dt))  601.84611  618.69288  690.47179  713.56760  766.1534    10
#  fun2(copy(dt))  887.99727  950.33821  978.98988 1071.31253 1180.1281    10
#  fun3(copy(dt)) 1566.90858 1574.30635 1603.55467 1673.38625 1771.4054    10
#  fun4(copy(dt))  566.43528  568.91103  575.06881  672.44021  692.9839    10

> identical(fun1(copy(dt)), fun4(copy(dt)))
[1] TRUE

接下来,我将数据大小增加了10倍,达到3000万行,即915MB。

请注意,这些时间现在只需几秒钟,而且在我的慢速上网本上。

set.seed(100)
Nrow = 30000000
dt <- data.table(a = sample(10, Nrow, TRUE), 
              b = sample(10, Nrow, TRUE), 
              c = sample(10, Nrow, TRUE), 
              x = rnorm(Nrow), 
              y = rnorm(Nrow),
              d = sample(10, Nrow, TRUE)) 
object.size(dt)/1024^2
# 915 MB
microbenchmark(copy(dt),fun1(copy(dt)), fun2(copy(dt)), fun3(copy(dt)), 
                 fun4(copy(dt)), times = 3)
# Unit: seconds
#            expr       min        lq    median       uq      max neval
#        copy(dt)   8.04262  53.68556  99.32849 269.4414 439.5544     3
#  fun1(copy(dt)) 207.70646 260.16710 312.62775 317.8966 323.1654     3
#  fun2(copy(dt)) 421.78934 502.03503 582.28073 658.0680 733.8553     3
#  fun3(copy(dt)) 104.30914 187.49875 270.68836 384.7804 498.8724     3
#  fun4(copy(dt)) 158.17239 165.35898 172.54557 183.4851 194.4246     3

在这里,{I}平均来说,fun4平均速度相当快,因为​​for循环的内存效率一次只有一列。在fun1fun3中,:=的RHS在此之前是三列宽,然后分配给三个目标列。话虽如此,为什么我之前fun2最慢?毕竟它逐列。在进入get()之前,可能会==复制该列。

有一次跑fun3最快(104 vs 158)。我不确定我是否相信microbenchmark。我似乎记得Radford Neal对microbenchmark的一些批评,但不记得结果。

这些时间安排在我真正慢的上网本上:

$ lscpu
Architecture:          x86_64
CPU op-mode(s):        32-bit, 64-bit
Byte Order:            Little Endian
CPU(s):                2
On-line CPU(s) list:   0,1
Thread(s) per core:    1
Core(s) per socket:    2
Socket(s):             1
NUMA node(s):          1
Vendor ID:             AuthenticAMD
CPU family:            20
Model:                 2
Stepping:              0
CPU MHz:               800.000
BogoMIPS:              1995.06
Virtualisation:        AMD-V
L1d cache:             32K
L1i cache:             32K
L2 cache:              512K
NUMA node0 CPU(s):     0,1

> sessionInfo()
R version 3.1.0 (2014-04-10)
Platform: x86_64-pc-linux-gnu (64-bit)   

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] microbenchmark_1.3-0 data.table_1.9.2     bit64_0.9-3          bit_1.1-11

答案 1 :(得分:3)

这似乎符合你的要求:

f <- function(col,x) dt[,(col):=(.SD==x),.SDcols=col]
lapply(list("a","b","c"), f, 3)
lapply(list("d"), f, 6)
dt
#        a     b     c           x          y     d
# 1: FALSE FALSE FALSE -0.50219235  0.1169713 FALSE
# 2: FALSE FALSE FALSE  0.13153117  0.3186301  TRUE
# 3:  TRUE FALSE FALSE -0.07891709 -0.5817907  TRUE
# 4: FALSE  TRUE  TRUE  0.88678481  0.7145327 FALSE

注意:

  • 第二个lapply(...)可以替换为:dt[,d:=(d==6)] 但我想知道你的真实案例是否比这更复杂。
  • 此方法不适用于数据框。原因很微妙 并且与通过引用更新的数据表有关。所以举个例子 如果将数据帧传递给函数,则按值传递 - 副本 是。如果在函数中修改它,则表示正在修改 复制;原件没有变化。另一方面,如果您传递数据 table到函数,它通过引用传递,并进行修改 函数内部反映在原始数据表中。

答案 2 :(得分:3)

我首先尝试的是:

> dt
   a b c           x          y d
1: 1 1 1 -0.50219235  0.1169713 4
2: 2 0 2  0.13153117  0.3186301 6
3: 3 1 1 -0.07891709 -0.5817907 6
4: 1 3 3  0.88678481  0.7145327 7
> for (i in c("a","b","c")) dt[get(i)==3, (i):=TRUE]
> dt[d==6, d:=TRUE]

但这得到了错误的答案:

> dt
   a b c           x          y d
1: 1 1 1 -0.50219235  0.1169713 4
2: 2 0 2  0.13153117  0.3186301 1
3: 1 1 1 -0.07891709 -0.5817907 1
4: 1 1 1  0.88678481  0.7145327 7
> 

这是因为:=的RHS被强制匹配列的类型,即TRUE在这种情况下被强制为1。但是你想要改变列的类型,这在data.table中是故意的。想象一下RAM中的20GB data.table - 你几乎不想改变列类型,因为这将涉及复制整个列。 99%的时间您希望将RHS强制转换为列的类型,例如:=1如果列类型为:=1L则强制为integer

要更改列类型,您需要 plonk 一个新列直接进入该列指针槽。你可以通过使RHS与行数一样长。

> for (i in c("a","b","c")) dt[, (i):=get(i)==3]
> for (i in c("d"))         dt[, (i):=get(i)==6]
> dt
       a     b     c           x          y     d
1: FALSE FALSE FALSE -0.50219235  0.1169713 FALSE
2: FALSE FALSE FALSE  0.13153117  0.3186301  TRUE
3:  TRUE FALSE FALSE -0.07891709 -0.5817907  TRUE
4: FALSE  TRUE  TRUE  0.88678481  0.7145327 FALSE
>

答案 3 :(得分:2)

这里的另一种方法在速度方面远远落后于前两种方法,可能更具可读性:

# the variables and values you want
vars = c('a','b','c','d')
values = c(3,3,3,6)

dt[, (vars) := Map('==', .SD, values), .SDcols = vars]