R:扫描矢量一次而不是4次?

时间:2015-01-14 17:20:17

标签: r

假设我有两个相等长度的逻辑向量。 以简单的方式计算confusion matrix

c(sum(actual == 1 & predicted == 1),
  sum(actual == 0 & predicted == 1),
  sum(actual == 1 & predicted == 0),
  sum(actual == 0 & predicted == 0))

需要扫描矢量 4 次。

是否可以一次性完成?

PS。我尝试了table(2*actual+predicted)table(actual,predicted),但两者显然都慢得多。

PPS。速度不是我的主要考虑因素,我对理解语言更感兴趣。

3 个答案:

答案 0 :(得分:6)

您可以尝试使用data.table

library(data.table)
DT <- data.table(actual, predicted)
setkey(DT, actual, predicted)[,.N, .(actual, predicted)]$N

数据

set.seed(24)
actual <- sample(0:1, 10 , replace=TRUE)
predicted <- sample(0:1, 10, replace=TRUE)

基准

使用data.table_1.9.5dplyr_0.4.0

library(microbenchmark)
set.seed(245)
actual <- sample(0:1, 1e6 , replace=TRUE)
predicted <- sample(0:1, 1e6, replace=TRUE)
f1 <- function(){
  DT <- data.table(actual, predicted)
  setkey(DT, actual, predicted)[,.N, .(actual, predicted)]$N}

f2 <- function(){table(actual, predicted)}
f3 <- function() {data_frame(actual, predicted) %>%
                      group_by(actual, predicted) %>% 
                      summarise(n())}

microbenchmark(f1(), f2(), f3(), unit='relative', times=20L)
#Unit: relative
# expr       min        lq      mean   median        uq       max neval cld
#f1()  1.000000  1.000000  1.000000  1.00000  1.000000  1.000000    20  a 
#f2() 20.818410 22.378995 22.321816 22.56931 22.140855 22.984667    20   b
#f3()  1.262047  1.248396  1.436559  1.21237  1.220109  2.504662    20  a 

countdplyr中的tabulate包括在稍微更大的数据集的基准中

set.seed(498)
actual <- sample(0:1, 1e7 , replace=TRUE)
predicted <- sample(0:1, 1e7, replace=TRUE)
f4 <- function() {data_frame(actual, predicted) %>% 
                       count(actual, predicted)}
f5 <- function(){tabulate(4-actual-2*predicted, 4)}

更新

在基准测试中包含另一个data.table解决方案(由@Arun提供)

f6 <- function() {setDT(list(actual, predicted))[,.N, keyby=.(V1,V2)]$N}

microbenchmark(f1(),  f3(), f4(), f5(), f6(),  unit='relative', times=20L)
#Unit: relative
#expr      min       lq     mean   median       uq      max neval  cld
#f1() 2.003088 1.974501 2.020091 2.015193 2.080961 1.924808    20   c 
#f3() 2.488526 2.486019 2.450749 2.464082 2.481432 2.141309    20    d
#f4() 2.388386 2.423604 2.430581 2.459973 2.531792 2.191576    20    d
#f5() 1.034442 1.125585 1.192534 1.217337 1.239453 1.294920    20  b  
#f6() 1.000000 1.000000 1.000000 1.000000 1.000000 1.000000    20 a   

答案 1 :(得分:5)

像这样:

tabulate(4 - actual - 2*predicted, 4)

tabulate这里比table快得多,因为它知道输出将是长度为4的向量。

答案 2 :(得分:2)

如果tableactual仅包含0和1,则会predicted计算交叉制表,并且会给出类似的结果:

table(actual, predicted)

在内部,这可以通过paste向量来实现 - 非常低效。似乎只有在列出一个值时才会强制character,这也可能是table(actual*2 + predicted)表现不佳的原因。