My goal is to compute the following triple summation:

(1 / (l0 * l1 * l2)) * sum_{i=1..l0} sum_{j=1..l1} sum_{k=1..l2} I(c0[i], c1[j], c2[k])

where I(Y1,Y2,Y3) is defined as:

I(Y1,Y2,Y3) = 1   if Y[1] < Y[2] < Y[3]
I(Y1,Y2,Y3) = 1/2 if Y[1] = Y[2] < Y[3]
I(Y1,Y2,Y3) = 1/6 if Y[1] = Y[2] = Y[3]
I(Y1,Y2,Y3) = 0   otherwise.

I have implemented the calculation in R with the following code:

set.seed(123)
nclasses = 3
ind <- function(Y){
  res = 0
  if (Y[1] < Y[2] & Y[2] < Y[3]){res = 1}
  else if (Y[1] == Y[2] & Y[2] < Y[3]){res = 1/2}
  else if (Y[1] == Y[2] & Y[2] == Y[3]){res = 1/6}
  else {res = 0}
  return (res)
}
N_obs = 300
c0 <- rnorm(N_obs)
l0 = length(c0)
c1 <- rnorm(N_obs)
l1 = length(c1)
c2 <- rnorm(N_obs)
l2 = length(c2)
mat <- matrix(unlist(t(matrix(expand.grid(c0,c1,c2)))), ncol= nclasses)
dim(mat)
Result <- (1/(l0*l1*l2))*sum(apply(mat, 1, ind))
The problem is that computing it this way is very expensive. I guess this has to do with using expand.grid() to create the matrix of all combinations and then computing the result.
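For scale, expand.grid() builds every (c0, c1, c2) combination, so apply() walks over 300^3 rows and calls ind() once for each of them:

# size of the problem: one ind() call per combination
nrow(mat)   # 300^3 = 27,000,000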
Does anyone have a more efficient approach?
Answer 0 (score: 1)
The original took 399 seconds on my computer to execute the Result <- line. The variation below, using dplyr and tidyr, took 7 seconds for the summation part, and I got exactly the same answer. I presume the speedup comes from the dplyr version being vectorized, doing the same calculation on all 27 million rows at once, whereas I suspect the original re-evaluates the function for every row.
library(dplyr); library(tidyr)

combos <- tibble(Y1 = rnorm(300),
                 Y2 = rnorm(300),
                 Y3 = rnorm(300)) %>%
  complete(Y1, Y2, Y3)

combos %>%
  mutate(res = case_when(Y1 < Y2 & Y2 < Y3   ~ 1,
                         Y1 == Y2 & Y2 < Y3  ~ 1/2,
                         Y1 == Y2 & Y2 == Y3 ~ 1/6,
                         TRUE ~ 0)) %>%
  summarize(mean_res = mean(res))
This also looks like it could be solved algebraically, but I assume it was meant to be solved by simulation. With three separate sets of 300 numbers, each drawn to 16 digits with rnorm, the chance of any exact ties between them is vanishingly small. So we can ignore the 2nd and 3rd cases, which never occur with the suggested set.seed and would likely take billions of runs to hit even once. Now Y[1] < Y[2] < Y[3] is just one of the 3! = 6 equally likely orderings of three iid continuous draws, so the sum should come out near 1/6 ≈ 0.167.
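A rough sanity check along those lines (a minimal sketch; the seed and the 1e5 replication count below are arbitrary choices, and ties are assumed to have essentially zero probability):

# Analytic value: the 3! orderings of three iid continuous draws are
# equally likely, so P(Y1 < Y2 < Y3) = 1/6.
1 / factorial(3)   # 0.1666667

# Quick Monte Carlo check with fresh triples (not the 27-million-row grid);
# the mean should land near 1/6.
set.seed(1)
mean(replicate(1e5, {
  y <- rnorm(3)
  y[1] < y[2] & y[2] < y[3]
}))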
Answer 1 (score: 1)
tl;dr - data.table with non-equi joins can solve the whole problem in about the same time it takes tidyr just to generate the data. Still, the tidyr / dplyr solution looks nicer.
data.table(c0
)[data.table(c1), on = .(c0 <= c1), .(c0 = x.c0, c1 = i.c1), allow.cartesian = T
][data.table(c2), on = .(c1 <= c2), .(c0 = x.c0, c1 = x.c1, c2 = i.c2), allow.cartesian = T
][c0 <= c1 & c1 <= c2, sum(ifelse(c0 < c1 & c1 < c2, 1,
ifelse(c0 == c1 & c1 < c2, 1/2, 1/6
)))
] / (length(c0) * length(c1) * length(c2))
There are two places to speed things up: how the data is generated, and then the calculation itself.

The fastest way to generate the data is to keep it simple. You can use as.matrix instead of the transpose-and-unlist for clarity and a slight speedup. Or you can keep the expand.grid result as a data.frame, similar to the tidyr solution of creating a tibble.

The data.table equivalent is CJ(c0, c1, c2), and it is about 10 times faster than the fastest base or tidyr equivalent.
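A minimal sketch of those alternatives side by side (same c0, c1, c2 vectors as in the question; the object names here are just for illustration, and the full benchmark code is in the Data / Code section at the end):

library(data.table)

mat2 <- as.matrix(expand.grid(c0, c1, c2))  # matrix, skipping the t()/unlist() detour
df   <- expand.grid(c0, c1, c2)             # or simply keep the plain data.frame
dtCJ <- CJ(c0, c1, c2)                      # data.table cross join of the three vectors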
#Creating dataset
Unit: milliseconds
            expr     min      lq    mean  median      uq     max neval
        original 1185.10 1239.37 1478.46 1503.68 1690.47 1899.37    10
       as.matrix 1023.49 1041.72 1213.17 1198.24 1360.51 1420.78    10
     expand.grid  764.43  840.11 1030.13 1030.79 1146.82 1354.06    10
  tidyr_complete 2811.00 2948.86 3118.33 3158.59 3290.21 3364.52    10
  tidyr_crossing 1154.94 1171.01 1311.71 1233.40 1545.30 1609.86    10
   data.table_CJ  154.71  155.30  175.65  162.54  174.96  291.14    10
The other approach is to use a non-equi join, or to pre-filter the data. We know that any combination with c0 > c1 or c1 > c2 contributes 0 to the summation, so we can filter out combinations we know we do not need to hold in memory, which also makes creating the combinations faster.

While both of these are slower than data.table::CJ(), they set a better stage for the triple summation.
# 'data.table_CJ_filter' = CJ(c0,c1,c2)[c0 <= c1 & c1 <= c2, ]
#'tidyr_cross_filter' = crossing(c0, c1) %>% filter(c0 <= c1) %>% crossing(c2) %>% filter(c1 <= c2)
#Creating dataset with future calcs in mind
Unit: milliseconds
                 expr    min     lq   mean median      uq     max neval
  data.table_non_equi 358.41 360.35 373.95 374.57  383.62  400.42    10
 data.table_CJ_filter 515.50 517.99 605.06 527.63  661.54  856.43    10
   tidyr_cross_filter 776.91 783.35 980.19 928.25 1178.47 1287.91    10
@Jon Spring's solution is great. case_when and ifelse are vectorized, whereas the original if ... else is not. I translated Jon's answer into base R; it is faster than your original solution, but still takes about 50% longer than dplyr.

Note that if you do the non-equi join, you can simplify the case_when even further, because the filtering has already been done - every remaining row gets 1, 1/2, or 1/6. Also note that the pre-filtered solutions are 10 to 30 times faster than working with the unfiltered data.
Unit: milliseconds
             expr     min      lq    mean  median      uq     max neval
             base 5666.93 6003.87 6303.27 6214.58 6416.42 7423.30    10
            dplyr 3633.48 3963.47 4160.68 4178.15 4395.96 4530.15    10
       data.table  236.83  262.10  305.19  268.47  269.44  495.22    10
 dplyr_pre_filter  378.79  387.38  459.67  418.58  448.13  765.74    10
The final solution provided at the top runs start-to-finish in under a second; the dplyr revision takes under 2 seconds. Both rely on pre-filtering before getting to the logical if ... else step.
Unit: milliseconds
      expr     min      lq    mean  median      uq    max neval
    dt_res  589.83  608.26  736.34  642.46  760.18 1091.1    10
 dt_CJ_res  750.07  764.78  905.12  893.73 1040.21 1140.5    10
 dplyr_res 1156.69 1169.84 1363.82 1337.42 1496.60 1709.8    10
Data / Code
# https://stackoverflow.com/questions/56185072/fastest-way-to-compute-this-triple-summation-in-r
library(dplyr)
library(tidyr)
library(data.table)
options(digits = 5)
set.seed(123)
nclasses = 3
N_obs = 300
c0 <- rnorm(N_obs)
c1 <- rnorm(N_obs)
c2 <- rnorm(N_obs)
# Base R Data Generation --------------------------------------------------
mat <- matrix(unlist(t(matrix(expand.grid(c0,c1,c2)))), ncol= nclasses)
df <- expand.grid(c0,c1,c2)
identical(mat, unname(as.matrix(df))) #TRUE - names are different with as.matrix
# tidyr and data.table Data Generation ------------------------------------
tib <- crossing(c0, c1, c2) #faster than complete

tib2 <- crossing(c0, c1) %>% #faster but similar in concept to the non-equi join
  filter(c0 <= c1) %>%
  crossing(c2) %>%
  filter(c1 <= c2)
dt <- data.table(c0
)[data.table(c1), on = .(c0 <= c1), .(c0 = x.c0, c1 = i.c1), allow.cartesian = T
][data.table(c2), on = .(c1 <= c2), .(c0 = x.c0, c1 = x.c1, c2 = i.c2), allow.cartesian = T
][c0 <= c1 & c1 <= c2, ]
# Base R summation --------------------------------------------------------
sum(ifelse(df$Var1 < df$Var2 & df$Var2 < df$Var3, 1,
           ifelse(df$Var1 == df$Var2 & df$Var2 < df$Var3, 1/2,
                  ifelse(df$Var1 == df$Var2 & df$Var2 == df$Var3, 1/6, 0)
           ))
) / (length(c0) * length(c1) * length(c2))
# dplyr summation ---------------------------------------------------------
tib %>%
  mutate(res = case_when(c0 < c1 & c1 < c2   ~ 1,
                         c0 == c1 & c1 < c2  ~ 1/2,
                         c0 == c1 & c1 == c2 ~ 1/6,
                         TRUE ~ 0)) %>%
  summarize(mean_res = mean(res))
# data.table summation ----------------------------------------------------
#why base doesn't have case_when, who knows
dt[, sum(ifelse(c0 < c1 & c1 < c2, 1,
                ifelse(c0 == c1 & c1 < c2, 1/2,
                       ifelse(c0 == c1 & c1 == c2, 1/6, 0)
                )))
] / (length(c0) * length(c1) * length(c2))

CJ(c0, c1, c2)[c0 <= c1 & c1 <= c2, sum(ifelse(c0 < c1 & c1 < c2, 1,
                                               ifelse(c0 == c1 & c1 < c2, 1/2, 1/6)
                                               ))
] / (length(c0) * length(c1) * length(c2))
# Benchmarking ------------------------------------------------------------
library(microbenchmark)
# Data generation
microbenchmark('original' = {
matrix(unlist(t(matrix(expand.grid(c0,c1,c2)))), ncol= nclasses)
}
, 'as.matrix' = {
as.matrix(expand.grid(c0,c1,c2))
}
, 'expand.grid' = {
expand.grid(c0,c1,c2) #keep it simpler
}
, 'tidyr_complete' = {
tibble(c0, c1, c2) %>% complete(c0, c1, c2)
}
, 'tidyr_crossing' = {
crossing(c0, c1, c2)
}
, 'data.table_CJ' = {
CJ(c0,c1,c2)
}
, times = 10)
microbenchmark('data.table_non_equi' = {
data.table(c0
)[data.table(c1), on = .(c0 <= c1), .(c0 = x.c0, c1 = i.c1), allow.cartesian = T
][data.table(c2), on = .(c1 <= c2), .(c0 = x.c0, c1 = x.c1, c2 = i.c2), allow.cartesian = T
][c0 <= c1 & c1 <= c2, ]
}
, 'data.table_CJ_filter' = {
CJ(c0,c1,c2)[c0 <= c1 & c1 <= c2, ]
}
, 'tidyr_cross_filter' = {
crossing(c0,c1)%>%filter(c0 <= c1)%>% crossing(c2)%>% filter(c1 <= c2)
}
, times = 10
)
# Summation Calculation
microbenchmark('base' = {
sum(ifelse(df$Var1 < df$Var2 & df$Var2 < df$Var3, 1,
ifelse(df$Var1 == df$Var2 & df$Var2 < df$Var3, 1/2,
ifelse(df$Var1 == df$Var2 & df$Var2 == df$Var3, 1/6, 0)
))
) / (length(c0)*length(c1)*length(c2))
}
, 'dplyr' = {
tib %>%
mutate(res = case_when(c0 < c1 & c1 < c2 ~ 1,
c0 == c1 & c1 < c2 ~ 1/2,
c0 == c1 & c1 == c2 ~ 1/6,
TRUE ~ 0)) %>%
summarize(mean_res = mean(res))
}
, 'data.table' = {
dt[, sum(ifelse(c0 < c1 & c1 < c2, 1,
ifelse(c0 == c1 & c1 < c2, 1/2, 1/6)
))
] / (length(c0) * length(c1) * length(c2))
}
, 'dplyr_pre_filter' = {
tib2 %>%
mutate(res = case_when(c0 < c1 & c1 < c2 ~ 1,
c0 == c1 & c1 < c2 ~ 1/2,
TRUE ~ 1/6)) %>%
summarize(mean_res = sum(res)) / (length(c0) * length(c1) * length(c2))
}
, times = 10)
# Start to Finish
microbenchmark('dt_res' = {
data.table(c0
)[data.table(c1), on = .(c0 <= c1), .(c0 = x.c0, c1 = i.c1), allow.cartesian = T
][data.table(c2), on = .(c1 <= c2), .(c0 = x.c0, c1 = x.c1, c2 = i.c2), allow.cartesian = T
][c0 <= c1 & c1 <= c2, sum(ifelse(c0 < c1 & c1 < c2, 1,
ifelse(c0 == c1 & c1 < c2, 1/2, 1/6)
))
] / (length(c0) * length(c1) * length(c2))
}
, 'dt_CJ_res' = {
CJ(c0, c1, c2)[c0 <= c1 & c1 <= c2, sum(ifelse(c0 < c1 & c1 < c2, 1,
ifelse(c0 == c1 & c1 < c2, 1/2, 1/6)
))
] / (length(c0) * length(c1) * length(c2))
}
, 'dplyr_res' = {
crossing(c0, c1)%>% #faster but similar in concept to non-equi
filter(c0 <= c1)%>%
crossing(c2)%>%
filter(c1 <= c2)%>%
mutate(res = case_when(c0 < c1 & c1 < c2 ~ 1,
c0 == c1 & c1 < c2 ~ 1/2,
TRUE ~ 1/6)) %>%
summarize(mean_res = sum(res)) / (length(c0) * length(c1) * length(c2))
}
, times = 10
)