Suppose I have a data frame like the following:
> foo = data.frame(x = 1:9, id = c(1, 1, 2, 2, 2, 3, 3, 3, 3))
> foo
  x id
1 1  1
2 2  1
3 3  2
4 4  2
5 5  2
6 6  3
7 7  3
8 8  3
9 9  3
I want a very efficient implementation of h(a, b) that computes the sum of (a - xi) * (b - xj) over all pairs xi, xj belonging to the same id class. For example, my current implementation is
h <- function(a, b, foo){
  a.diff = a - foo$x
  b.diff = b - foo$x
  prod = a.diff %*% t(b.diff)
  # indicator: 1 where two rows share an id, 0 otherwise
  id.indicator = as.matrix(ifelse(dist(foo$id, diag = T, upper = T), 0, 1)) + diag(nrow(foo))
  return(sum(prod * id.indicator))
}
For example, with (a, b) = (0, 1), here is the output of each step in the function:
> a.diff
[1] -1 -2 -3 -4 -5 -6 -7 -8 -9
> b.diff
[1] 0 -1 -2 -3 -4 -5 -6 -7 -8
> prod
     [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9]
[1,]    0    1    2    3    4    5    6    7    8
[2,]    0    2    4    6    8   10   12   14   16
[3,]    0    3    6    9   12   15   18   21   24
[4,]    0    4    8   12   16   20   24   28   32
[5,]    0    5   10   15   20   25   30   35   40
[6,]    0    6   12   18   24   30   36   42   48
[7,]    0    7   14   21   28   35   42   49   56
[8,]    0    8   16   24   32   40   48   56   64
[9,]    0    9   18   27   36   45   54   63   72
> id.indicator
1 2 3 4 5 6 7 8 9
1 1 1 0 0 0 0 0 0 0
2 1 1 0 0 0 0 0 0 0
3 0 0 1 1 1 0 0 0 0
4 0 0 1 1 1 0 0 0 0
5 0 0 1 1 1 0 0 0 0
6 0 0 0 0 0 1 1 1 1
7 0 0 0 0 0 1 1 1 1
8 0 0 0 0 0 1 1 1 1
9 0 0 0 0 0 1 1 1 1
In practice there can be up to 1000 id clusters with at least 40 members each, which makes this approach far too inefficient: id.indicator is mostly zeros, and every entry of prod on the off-block-diagonals is computed only to be thrown away.
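To put rough numbers on the scale (a back-of-the-envelope check, assuming exactly 1000 clusters of 40 rows):

n <- 1000 * 40   # ~1000 id clusters, at least 40 rows each
n^2              # entries in each dense n x n matrix (prod, id.indicator)
#[1] 1.6e+09
n^2 * 8 / 2^30   # rough memory for one such matrix of doubles, in GiB
#[1] 11.92093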
Answer 0 (score: 6)
I had a crack at it. First, your implementation:
foo = data.frame(x = 1:9, id = c(1, 1, 2, 2, 2, 3, 3, 3, 3))
h <- function(a, b, foo){
  a.diff = a - foo$x
  b.diff = b - foo$x
  prod = a.diff %*% t(b.diff)
  id.indicator = as.matrix(ifelse(dist(foo$id, diag = T, upper = T), 0, 1)) +
    diag(nrow(foo))
  return(sum(prod * id.indicator))
}
h(a = 1, b = 0, foo = foo)
#[1] 891
Next, I tried a variant using a proper sparse matrix implementation (via the Matrix package) and a function that constructs the indicator matrix directly. I also used tcrossprod, which I often find to be a bit faster than a %*% t(b).
library("Matrix")
h2 <- function(a, b, foo) {
  a.diff <- a - foo$x
  b.diff <- b - foo$x
  prod <- tcrossprod(a.diff, b.diff)  # the same as a.diff %*% t(b.diff)
  # build the block-diagonal indicator directly: one all-ones block per id group
  id.indicator <- do.call(bdiag, lapply(table(foo$id), function(n) matrix(1, n, n)))
  return(sum(prod * id.indicator))
}
h2(a = 1, b = 0, foo = foo)
#[1] 891
Note that this function relies on foo$id being sorted.
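If it isn't, a one-line guard (my addition, not part of the original answer) restores that assumption before calling h2:

foo <- foo[order(foo$id), ]  # group rows contiguously by id, in table() order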
Finally, I tried to avoid creating the full n-by-n matrix altogether:
h3 <- function(a, b, foo) {
  a.diff <- a - foo$x
  b.diff <- b - foo$x
  ids <- unique(foo$id)
  res <- 0
  for (i in seq_along(ids)) {
    indx <- which(foo$id == ids[i])
    res <- res + sum(tcrossprod(a.diff[indx], b.diff[indx]))
  }
  return(res)
}
h3(a = 1, b = 0, foo = foo)
#[1] 891
Benchmarking on your example:
library("microbenchmark")
microbenchmark(h(a = 1, b = 0, foo = foo),
               h2(a = 1, b = 0, foo = foo),
               h3(a = 1, b = 0, foo = foo))
# Unit: microseconds
# expr min lq mean median uq max neval
# h(a = 1, b = 0, foo = foo) 248.569 261.9530 493.2326 279.3530 298.2825 21267.890 100
# h2(a = 1, b = 0, foo = foo) 4793.546 4893.3550 5244.7925 5051.2915 5386.2855 8375.607 100
# h3(a = 1, b = 0, foo = foo) 213.386 227.1535 243.1576 234.6105 248.3775 334.612 100
现在,在此示例中,h3
最快,h2
非常慢。但我想对于更大的例子来说两者都会更快。尽管如此,h3
可能会赢得更大的例子。虽然有足够的空间进行更多优化,但h3
应该更快,内存效率更高。所以,我认为你应该选择h3
的变体,它不会产生不必要的大矩阵。
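In fact, since sum(tcrossprod(u, v)) is just sum(u) * sum(v), one can push this further and avoid matrix products entirely. A minimal sketch (h4 is a made-up name, not from the original answer):

# Within each id group, sum_{i,j} (a - x_i)(b - x_j) factorizes as
# (sum_i (a - x_i)) * (sum_j (b - x_j)), so per-group sums are enough.
h4 <- function(a, b, foo) {
  sum(tapply(a - foo$x, foo$id, sum) * tapply(b - foo$x, foo$id, sum))
}
h4(a = 1, b = 0, foo = foo)
#[1] 891

This runs in O(n) time and memory, since no group-sized matrix is ever formed.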
Answer 1 (score: 3)
a <- 0; b <- 1  # the one-liner below reads a and b from the enclosing environment
sum(sapply(split(foo, foo$id), function(d) sum(outer(a-d$x, b-d$x))))
#[1] 891
#TESTING
foo = data.frame(x = sample(1:9, 10000, replace = TRUE),
                 id = sample(1:3, 10000, replace = TRUE))
system.time(sum(sapply(split(foo, foo$id), function(d) sum(outer(a-d$x, b-d$x)))))
# user system elapsed
# 0.15 0.01 0.17
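For reference, a self-contained wrapper of this one-liner, so that a and b are proper arguments rather than globals (h_split is a name invented for this sketch):

h_split <- function(a, b, foo) {
  # split rows by id, form the outer product of the two difference
  # vectors within each group, and add everything up
  sum(sapply(split(foo, foo$id), function(d) sum(outer(a - d$x, b - d$x))))
}
h_split(0, 1, foo)  # 891 on the original 9-row foo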
Answer 2 (score: 3)
tapply lets you apply a function across groups of a vector, and simplifies the result to a matrix or vector if it can. Using tcrossprod to multiply all the combinations within each group, it performs well on suitably big data:
# setup
set.seed(47)
foo = data.frame(x = 1:9, id = c(1, 1, 2, 2, 2, 3, 3, 3, 3))
foo2 <- data.frame(id = sample(1000, 40000, TRUE), x = rnorm(40000))
h_OP <- function(a, b, foo){
  a.diff = a - foo$x
  b.diff = b - foo$x
  prod = a.diff %*% t(b.diff)
  id.indicator = as.matrix(ifelse(dist(foo$id, diag = T, upper = T), 0, 1)) + diag(nrow(foo))
  return(sum(prod * id.indicator))
}

h3_AEBilgrau <- function(a, b, foo) {
  a.diff <- a - foo$x
  b.diff <- b - foo$x
  ids <- unique(foo$id)
  res <- 0
  for (i in seq_along(ids)) {
    indx <- which(foo$id == ids[i])
    res <- res + sum(tcrossprod(a.diff[indx], b.diff[indx]))
  }
  return(res)
}

h_d.b <- function(a, b, foo){
  sum(sapply(split(foo, foo$id), function(d) sum(outer(a-d$x, b-d$x))))
}

h_alistaire <- function(a, b, foo){
  sum(tapply(foo$x, foo$id, function(x){sum(tcrossprod(a - x, b - x))}))
}
All of them return the same thing, and they are not that different on small data:
h_OP(0, 1, foo)
#> [1] 891
h3_AEBilgrau(0, 1, foo)
#> [1] 891
h_d.b(0, 1, foo)
#> [1] 891
h_alistaire(0, 1, foo)
#> [1] 891
# small data test
microbenchmark::microbenchmark(
  h_OP(0, 1, foo),
  h3_AEBilgrau(0, 1, foo),
  h_d.b(0, 1, foo),
  h_alistaire(0, 1, foo)
)
#> Unit: microseconds
#> expr min lq mean median uq max neval cld
#> h_OP(0, 1, foo) 143.749 157.8895 189.5092 189.7235 214.3115 262.258 100 b
#> h3_AEBilgrau(0, 1, foo) 80.970 93.8195 112.0045 106.9285 125.9835 225.855 100 a
#> h_d.b(0, 1, foo) 355.084 381.0385 467.3812 437.5135 516.8630 2056.972 100 c
#> h_alistaire(0, 1, foo) 148.735 165.1360 194.7361 189.9140 216.7810 287.990 100 b
On bigger data, the differences become more distinct. The original was liable to crash my laptop, so here are benchmarks of the two fastest only:
# on 1k groups, 40k rows
microbenchmark::microbenchmark(
  h3_AEBilgrau(0, 1, foo2),
  h_alistaire(0, 1, foo2)
)
#> Unit: milliseconds
#> expr min lq mean median uq max neval cld
#> h3_AEBilgrau(0, 1, foo2) 336.98199 403.04104 412.06778 410.52391 423.33008 443.8286 100 b
#> h_alistaire(0, 1, foo2) 14.00472 16.25852 18.07865 17.22296 18.09425 96.9157 100 a
Another possibility is to summarise by group in a data frame and then sum the appropriate column. In base R you'd do this with aggregate, but dplyr and data.table are popular because they make this approach simpler for more complicated aggregations.

aggregate is slower than tapply. dplyr is faster than aggregate, but still slower than tapply. data.table, which is designed for speed, is almost as fast as tapply.
library(dplyr)
library(data.table)
h_aggregate <- function(a, b, foo){
  sum(aggregate(x ~ id, foo, function(x){sum(tcrossprod(a - x, b - x))})$x)
}

tidy_h <- function(a, b, foo){
  foo %>%
    group_by(id) %>%
    summarise(x = sum(tcrossprod(a - x, b - x))) %>%
    select(x) %>%
    sum()
}

h_dt <- function(a, b, foo){
  setDT(foo)[, .(x = sum(tcrossprod(a - x, b - x))), by = id][, sum(x)]
}
microbenchmark::microbenchmark(
  h_alistaire(1, 0, foo2),
  h_aggregate(1, 0, foo2),
  tidy_h(1, 0, foo2),
  h_dt(1, 0, foo2)
)
#> Unit: milliseconds
#> expr min lq mean median uq max neval cld
#> h_alistaire(1, 0, foo2) 13.30518 15.52003 18.64940 16.48818 18.13686 62.35675 100 a
#> h_aggregate(1, 0, foo2) 93.08401 96.61465 107.14391 99.16724 107.51852 143.16473 100 c
#> tidy_h(1, 0, foo2) 39.47244 42.22901 45.05550 43.94508 45.90303 90.91765 100 b
#> h_dt(1, 0, foo2) 13.31817 15.09805 17.27085 16.46967 17.51346 56.34200 100 a
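One caveat worth flagging with the data.table version: setDT converts its argument by reference, so h_dt silently turns the caller's foo2 into a data.table. If that side effect matters, a copying variant (h_dt_copy is a made-up name for this sketch) avoids it at the cost of one copy:

h_dt_copy <- function(a, b, foo){
  dt <- as.data.table(foo)  # copies, so the caller's data frame is left untouched
  dt[, .(x = sum(tcrossprod(a - x, b - x))), by = id][, sum(x)]
}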