Fastest way to compute this triple summation in R

Time: 2019-05-17 10:59:52

Tags: r matrix

My goal is to compute the following triple summation:

Result = 1/(l0*l1*l2) * sum_{i=1..l0} sum_{j=1..l1} sum_{k=1..l2} I(c0[i], c1[j], c2[k])

where I(Y1,Y2,Y3) is defined as:

I(Y1,Y2,Y3) = 1 if Y[1] < Y[2] < Y[3]
I(Y1,Y2,Y3) = 1/2 if Y[1] = Y[2] < Y[3]
I(Y1,Y2,Y3) = 1/6 if Y[1] = Y[2] = Y[3]
I(Y1,Y2,Y3) = 0 otherwise.

I have implemented the computation in R; the code is:

set.seed(123)

nclasses = 3

ind <- function(Y){
  res = 0
  if (Y[1] < Y[2] & Y[2] < Y[3]){res = 1}
  else if (Y[1] == Y[2] & Y[2] < Y[3]){res = 1/2}
  else if (Y[1] == Y[2] & Y[2] == Y[3]){res = 1/6}
  else {res = 0}
  return (res)
}

N_obs = 300

c0 <- rnorm(N_obs)
l0 = length(c0)
c1 <- rnorm(N_obs)
l1 = length(c1)
c2 <- rnorm(N_obs)
l2 = length(c2)

mat <- matrix(unlist(t(matrix(expand.grid(c0,c1,c2)))), ncol= nclasses)
dim(mat)

Result <- (1/(l0*l1*l2))*sum(apply(mat, 1, ind))

The problem is that computing it this way is very expensive. I suspect this comes from using expand.grid() to create a matrix of all combinations and then computing the result.

Does anyone have a more efficient way to do this?
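As a rough back-of-the-envelope check of why materializing every combination is costly, the sketch below estimates the size of the full grid for N_obs = 300. It assumes the values are stored as 8-byte doubles and ignores the temporary copies made by unlist(), t() or apply(), so the real peak memory use is likely higher.

```r
# Rough size of the full grid of combinations (assumes 8-byte doubles)
N_obs    <- 300
nclasses <- 3

n_rows  <- N_obs^nclasses          # one row per (i, j, k) triple
n_bytes <- n_rows * nclasses * 8   # three double columns per row

n_rows             # number of rows that apply() has to visit one by one
n_bytes / 1024^2   # approximate size in MiB of a single copy of the grid
```

The row count grows with the cube of N_obs, so doubling the sample size multiplies both the memory and the number of ind() calls by eight.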


2 answers:

Answer 0: (score: 1)

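The original took 399 seconds on my computer to execute the Result <- line. This variation using dplyr and tidyr took 7 seconds to do the summation part, and I get exactly the same answer. I think the speed-up comes from the dplyr version being vectorized, so the same calculation is applied to all 27 million rows at once, whereas I suspect the original calls the function again for every row.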

library(dplyr); library(tidyr)

combos <- tibble(Y1 = rnorm(300),
                 Y2 = rnorm(300),
                 Y3 = rnorm(300)) %>%
  complete(Y1, Y2, Y3)

combos %>%
  mutate(res = case_when(Y1  < Y2 & Y2 < Y3  ~ 1,
                         Y1 == Y2 & Y2 < Y3  ~ 1/2,
                         Y1 == Y2 & Y2 == Y3 ~ 1/6,
                         TRUE               ~ 0)) %>%
  summarize(mean_res = mean(res))
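To make the vectorization point concrete, here is a minimal base R sketch on a deliberately small grid (n = 20, so 8,000 rows). It is not the dplyr code above; the names row_wise and vectorized are only illustrative. It relies on the three non-zero cases being mutually exclusive, so their weighted logical vectors can simply be added.

```r
set.seed(1)
n <- 20
g <- expand.grid(Y1 = rnorm(n), Y2 = rnorm(n), Y3 = rnorm(n))

# Row by row: ind() is called once per row, i.e. n^3 separate R function calls
ind <- function(Y) {
  if (Y[1] < Y[2] && Y[2] < Y[3]) {
    1
  } else if (Y[1] == Y[2] && Y[2] < Y[3]) {
    1/2
  } else if (Y[1] == Y[2] && Y[2] == Y[3]) {
    1/6
  } else {
    0
  }
}
row_wise <- mean(apply(as.matrix(g), 1, ind))

# Vectorized: each comparison runs once over whole columns
strict  <- g$Y1 <  g$Y2 & g$Y2 <  g$Y3
tie_12  <- g$Y1 == g$Y2 & g$Y2 <  g$Y3
tie_all <- g$Y1 == g$Y2 & g$Y2 == g$Y3
vectorized <- mean(strict + tie_12 / 2 + tie_all / 6)

all.equal(row_wise, vectorized)
```

Both versions compute the same mean; the difference is only in how many times R has to dispatch a function call.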

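This seems like it could also be solved algebraically, but I assume the intent is to solve it by simulation.

If we have three separate sets of 300 numbers, each number with 16 digits and drawn using rnorm, the chance of any of them matching each other is minuscule. So we can ignore the second and third cases, which do not occur with the suggested set.seed and might take billions of runs to encounter even once.

Now Y[1] … with set.seed(123), 22,379,120 of the 27,000,000 cases are ascending (82.9%).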
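Following up on the remark that this looks solvable without enumerating every triple: if ties are ignored (an assumption that holds with probability one for continuous draws such as rnorm), the triple sum only counts triples with c0[i] < c1[j] < c2[k]. For a fixed c1[j], the number of such triples is the count of c0 values below it times the count of c2 values above it, so the grid never has to be built. This is a swapped-in counting technique rather than what either answer does, and the helper name ordered_share is hypothetical.

```r
# Share of triples with a[i] < b[j] < d[k], without building the grid.
# Assumes no ties (the 1/2 and 1/6 cases are ignored).
ordered_share <- function(a, b, d) {
  per_middle <- vapply(
    b,
    function(y) as.numeric(sum(a < y)) * as.numeric(sum(d > y)),
    numeric(1)
  )
  sum(per_middle) / (length(a) * length(b) * length(d))
}

set.seed(123)
c0 <- rnorm(300)
c1 <- rnorm(300)
c2 <- rnorm(300)
result <- ordered_share(c0, c1, c2)

# Sanity check against brute force on a small subset
g <- expand.grid(a = c0[1:15], b = c1[1:15], d = c2[1:15])
brute <- mean(g$a < g$b & g$b < g$d)
all.equal(brute, ordered_share(c0[1:15], c1[1:15], c2[1:15]))
```

This does about N^2 comparisons instead of touching N^3 rows. If the inputs could contain ties, the 1/2 and 1/6 cases would need their own counts.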

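Answer 1: (score: 1)

tl;dr - using data.table with non-equi joins solves the problem in about the same time it takes tidyr to finish generating the data. Still, the tidyr / dplyr solution looks nicer.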

data.table(c0
)[data.table(c1), on = .(c0 <= c1), .(c0 = x.c0, c1 = i.c1), allow.cartesian = T
  ][data.table(c2), on = .(c1 <= c2), .(c0 = x.c0, c1 = x.c1, c2 = i.c2), allow.cartesian = T
    ][c0 <= c1 & c1 <= c2, sum(ifelse(c0 < c1 & c1 < c2, 1,
                                      ifelse(c0 == c1 & c1 < c2, 1/2, 1/6
                                      )))
      ] / (length(c0) * length(c1) * length(c2))

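There are two places to gain speed: how the data is generated, and then the calculation itself.

Generating the data

The fastest approach is to keep it simple. You can use as.matrix instead of transposing and unlisting, which keeps the code clear and gives a slight speed-up. Or you can leave the expand.grid result as a data.frame, similar to the tidyr solution that creates a tibble.

The data.table equivalent is CJ(c0, c1, c2). By median time in the benchmark below, it is roughly 6 times faster than the fastest base equivalent (expand.grid) and roughly 8 times faster than the fastest tidyr equivalent (crossing).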

#Creating dataset
Unit: milliseconds
                expr     min      lq    mean  median      uq     max neval
            original 1185.10 1239.37 1478.46 1503.68 1690.47 1899.37    10
           as.matrix 1023.49 1041.72 1213.17 1198.24 1360.51 1420.78    10
         expand.grid  764.43  840.11 1030.13 1030.79 1146.82 1354.06    10
      tidyr_complete 2811.00 2948.86 3118.33 3158.59 3290.21 3364.52    10
      tidyr_crossing 1154.94 1171.01 1311.71 1233.40 1545.30 1609.86    10
       data.table_CJ  154.71  155.30  175.65  162.54  174.96  291.14    10

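Another approach is to use a non-equi join or to pre-filter the data. We know that if c0 > c1 or c1 > c2, the summand is 0. So we can filter out combinations that we know never need to be stored in memory, which makes creating the combinations faster.

While both of these approaches are slower than data.table::CJ(), they set the stage better for the triple summation.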

# 'data.table_CJ_filter' = CJ(c0,c1,c2)[c0 <= c1 & c1 <= c2, ]
#'tidyr_cross_filter' =  crossing(c0, c1) %>% filter(c0 <= c1) %>% crossing(c2) %>% filter(c1 <= c2)

#Creating dataset with future calcs in mind
Unit: milliseconds
                 expr    min     lq   mean median      uq     max neval
  data.table_non_equi 358.41 360.35 373.95 374.57  383.62  400.42    10
 data.table_CJ_filter 515.50 517.99 605.06 527.63  661.54  856.43    10
   tidyr_cross_filter 776.91 783.35 980.19 928.25 1178.47 1287.91    10
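For readers unfamiliar with non-equi joins, here is a tiny sketch of the first join step on made-up vectors (a and b are stand-ins for c0 and c1, not data from the question). As far as I understand data.table's join semantics, the x. and i. prefixes in j select the original values from the outer and inner table respectively, which is why the answer's code uses them.

```r
library(data.table)

a <- c(1, 5, 9)   # plays the role of c0
b <- c(4, 6)      # plays the role of c1

# For every row of the inner table (i), keep the rows of the outer table (x)
# with a <= b. x.a and i.b pick the original values from each side.
pairs <- data.table(a)[data.table(b), on = .(a <= b),
                       .(a = x.a, b = i.b), allow.cartesian = TRUE]

pairs         # only the pairs that satisfy a <= b
nrow(pairs)   # fewer rows than the full 3 x 2 cross join
```

Note that with the default nomatch setting, a value of b that is smaller than every a would still produce one row with a missing a, which may be why the answer keeps a final c0 <= c1 & c1 <= c2 filter after the joins.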

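Computing the sum

@Jon Spring's solution is great. case_when and ifelse are vectorized, whereas the original if ... else is not. I translated Jon's answer into base R. It is faster than your original solution but still takes about 50% longer than dplyr.

Note that if you did the non-equi join, you can simplify the case_when further because we have already filtered: every remaining row evaluates to 1, 1/2 or 1/6. Note also that the pre-filtered solutions are about 10 to 30 times faster than the ones run on unfiltered data.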
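A small base R sketch of the same idea, assuming a reduced sample size (n = 30) so that the unfiltered grid stays tiny. It uses merge(..., by = NULL) for the cross join, which is my substitution for crossing() and not part of the answer. One caveat worth stating: the simplified last branch also gives 1/6 to rows with c0 < c1 == c2, which the original definition scores as 0, so the shortcut is only exact when such ties do not occur (as with continuous draws).

```r
set.seed(123)
n  <- 30
c0 <- rnorm(n)
c1 <- rnorm(n)
c2 <- rnorm(n)

# Unfiltered: score all n^3 rows, most of which contribute 0
full <- expand.grid(c0 = c0, c1 = c1, c2 = c2)
full_sum <- sum(ifelse(full$c0 < full$c1 & full$c1 < full$c2, 1,
                ifelse(full$c0 == full$c1 & full$c1 < full$c2, 1/2,
                ifelse(full$c0 == full$c1 & full$c1 == full$c2, 1/6, 0))))

# Pre-filtered: drop pairs with c0 > c1 before crossing with c2
pairs <- expand.grid(c0 = c0, c1 = c1)
pairs <- pairs[pairs$c0 <= pairs$c1, ]
kept  <- merge(pairs, data.frame(c2 = c2), by = NULL)  # Cartesian product
kept  <- kept[kept$c1 <= kept$c2, ]

# No zero branch is needed any more
kept_sum <- sum(ifelse(kept$c0 < kept$c1 & kept$c1 < kept$c2, 1,
                ifelse(kept$c0 == kept$c1 & kept$c1 < kept$c2, 1/2, 1/6)))

c(rows_full = nrow(full), rows_kept = nrow(kept))
all.equal(full_sum, kept_sum)
```

Both sums are divided by the same n^3 afterwards, so the final result is unchanged while far fewer rows are scored.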

Unit: milliseconds
             expr     min      lq    mean  median      uq     max neval
             base 5666.93 6003.87 6303.27 6214.58 6416.42 7423.30    10
            dplyr 3633.48 3963.47 4160.68 4178.15 4395.96 4530.15    10
       data.table  236.83  262.10  305.19  268.47  269.44  495.22    10
 dplyr_pre_filter  378.79  387.38  459.67  418.58  448.13  765.74    10

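Putting it all together

The final solution given at the start of this answer runs in under a second. The dplyr revision runs in under 2 seconds. Both solutions rely on pre-filtering before reaching the logical if ... else statements.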

Unit: milliseconds
      expr     min      lq    mean  median      uq    max neval
    dt_res  589.83  608.26  736.34  642.46  760.18 1091.1    10
 dt_CJ_res  750.07  764.78  905.12  893.73 1040.21 1140.5    10
 dplyr_res 1156.69 1169.84 1363.82 1337.42 1496.60 1709.8    10

Data / code

# https://stackoverflow.com/questions/56185072/fastest-way-to-compute-this-triple-summation-in-r
library(dplyr)
library(tidyr)
library(data.table)

options(digits = 5)
set.seed(123)

nclasses = 3
N_obs = 300

c0 <- rnorm(N_obs)
c1 <- rnorm(N_obs)
c2 <- rnorm(N_obs)

# Base R Data Generation --------------------------------------------------

mat <- matrix(unlist(t(matrix(expand.grid(c0,c1,c2)))), ncol= nclasses)
df <- expand.grid(c0,c1,c2)

identical(mat, unname(as.matrix(df))) #TRUE - names are different with as.matrix

# tidyr and data.table Data Generation ------------------------------------

tib <- crossing(c0, c1, c2) #faster than complete

tib2 <- crossing(c0, c1)%>% #faster but similar in concept to non-equi
  filter(c0 <= c1)%>%
  crossing(c2)%>%
  filter(c1 <= c2)

dt <-   data.table(c0
                   )[data.table(c1), on = .(c0 <= c1), .(c0 = x.c0, c1 = i.c1), allow.cartesian = T
                     ][data.table(c2), on = .(c1 <= c2), .(c0 = x.c0, c1 = x.c1, c2 = i.c2), allow.cartesian = T
                       ][c0 <= c1 & c1 <= c2, ]

# Base R summation --------------------------------------------------------

sum(ifelse(df$Var1 < df$Var2 & df$Var2 < df$Var3, 1,
                      ifelse(df$Var1 == df$Var2 & df$Var2 < df$Var3, 1/2,
                             ifelse(df$Var1 == df$Var2 & df$Var2 == df$Var3, 1/6, 0)
                      ))
    ) / (length(c0)*length(c1)*length(c2))


# dplyr summation ---------------------------------------------------------

tib %>%
  mutate(res = case_when(c0  < c1 & c1 < c2  ~ 1,
                         c0 == c1 & c1 < c2  ~ 1/2,
                         c0 == c1 & c1 == c2 ~ 1/6,
                         TRUE               ~ 0)) %>%
  summarize(mean_res = mean(res))

# data.table summation ----------------------------------------------------

#why base doesn't have case_when, who knows
# dt is already filtered to c0 <= c1 & c1 <= c2, so the last branch needs no test
dt[, sum(ifelse(c0 < c1 & c1 < c2, 1,
                ifelse(c0 == c1 & c1 < c2, 1/2, 1/6)
                ))
   ] / (length(c0) * length(c1) * length(c2))


CJ(c0,c1,c2)[c0 <= c1 & c1 <= c2, sum(ifelse(c0 < c1 & c1 < c2, 1,
                                             ifelse(c0 == c1 & c1 < c2, 1/2, 1/6
                                             )))
             ] / (length(c0) * length(c1) * length(c2))

# Benchmarking ------------------------------------------------------------

library(microbenchmark)

# Data generation
microbenchmark('original' = {
  matrix(unlist(t(matrix(expand.grid(c0,c1,c2)))), ncol= nclasses)
}
, 'as.matrix' = {
  as.matrix(expand.grid(c0,c1,c2)) 
}
, 'expand.grid' = {
  expand.grid(c0,c1,c2) #keep it simpler
}
, 'tidyr_complete' = {
  tibble(c0, c1, c2) %>% complete(c0, c1, c2)
}
, 'tidyr_crossing' = {
  crossing(c0, c1, c2)
}
, 'data.table_CJ' = {
  CJ(c0,c1,c2)
}
, times = 10)

microbenchmark('data.table_non_equi' = {
  data.table(c0
             )[data.table(c1), on = .(c0 <= c1), .(c0 = x.c0, c1 = i.c1), allow.cartesian = T
               ][data.table(c2), on = .(c1 <= c2), .(c0 = x.c0, c1 = x.c1, c2 = i.c2), allow.cartesian = T
                 ][c0 <= c1 & c1 <= c2, ]
}
, 'data.table_CJ_filter' = {
  CJ(c0,c1,c2)[c0 <= c1 & c1 <= c2, ]
}
, 'tidyr_cross_filter' = {
  crossing(c0,c1)%>%filter(c0 <= c1)%>% crossing(c2)%>% filter(c1 <= c2)
}
, times = 10
)

# Summation Calculation
microbenchmark('base' = {
  sum(ifelse(df$Var1 < df$Var2 & df$Var2 < df$Var3, 1,
             ifelse(df$Var1 == df$Var2 & df$Var2 < df$Var3, 1/2,
                    ifelse(df$Var1 == df$Var2 & df$Var2 == df$Var3, 1/6, 0)
             ))
  ) / (length(c0)*length(c1)*length(c2))
}
, 'dplyr' = {
  tib %>%
    mutate(res = case_when(c0  < c1 & c1 < c2  ~ 1,
                           c0 == c1 & c1 < c2  ~ 1/2,
                           c0 == c1 & c1 == c2 ~ 1/6,
                           TRUE               ~ 0)) %>%
    summarize(mean_res = mean(res))
}
, 'data.table' = {
  dt[, sum(ifelse(c0 < c1 & c1 < c2, 1,
                  ifelse(c0 == c1 & c1 < c2, 1/2, 1/6)
                  ))
     ] / (length(c0) * length(c1) * length(c2))
}
, 'dplyr_pre_filter' = {
  tib2 %>%
    mutate(res = case_when(c0  < c1 & c1 < c2  ~ 1,
                           c0 == c1 & c1 < c2  ~ 1/2,
                           TRUE ~ 1/6)) %>%
    summarize(mean_res = sum(res)) / (length(c0) * length(c1) * length(c2))
}
, times = 10)

# Start to Finish

microbenchmark('dt_res' = {
  data.table(c0
)[data.table(c1), on = .(c0 <= c1), .(c0 = x.c0, c1 = i.c1), allow.cartesian = T
  ][data.table(c2), on = .(c1 <= c2), .(c0 = x.c0, c1 = x.c1, c2 = i.c2), allow.cartesian = T
    ][c0 <= c1 & c1 <= c2, sum(ifelse(c0 < c1 & c1 < c2, 1,
                                      ifelse(c0 == c1 & c1 < c2, 1/2, 1/6)
    ))
    ] / (length(c0) * length(c1) * length(c2))
}
, 'dt_CJ_res' = {
  CJ(c0, c1, c2)[c0 <= c1 & c1 <= c2, sum(ifelse(c0 < c1 & c1 < c2, 1,
                                                 ifelse(c0 == c1 & c1 < c2, 1/2, 1/6)
  ))
  ] / (length(c0) * length(c1) * length(c2))
}
, 'dplyr_res' = {
  crossing(c0, c1)%>% #faster but similar in concept to non-equi
    filter(c0 <= c1)%>%
    crossing(c2)%>%
    filter(c1 <= c2)%>%
    mutate(res = case_when(c0  < c1 & c1 < c2  ~ 1,
                           c0 == c1 & c1 < c2  ~ 1/2,
                           TRUE ~ 1/6)) %>%
    summarize(mean_res = sum(res)) / (length(c0) * length(c1) * length(c2))
}
, times = 10
)