加快二次形式的评价

时间:2014-12-11 13:49:09

标签: r for-loop vectorization

我的问题是另一个" Vectorize this!"。类似的问题出现在其他地方(Efficient way of calculating quadratic forms: avoid for loops?),但不知何故,我似乎无法使其适用于我的案例。

我想为大小为x'Sx的样本中的每个p维度观察x计算二次形式n。我无法弄清楚一个漂亮的矢量化代码,所以我最后的选择是for loop。以下玩具示例适用于p=2n=100

set.seed(123)
n <- 100
x1 <- rnorm(n)
x2 <- rnorm(n)
x <- cbind(x1,x2)
Sigma <- matrix(c(1, 2, 3, 4), ncol = 2)
z  <- rep(0, n)
for (i in 1:n) {
   z[i]  <- x[i, ] %*% solve(Sigma, x[i, ]) #quadratic form of x'S^{-1}x
}

与崇拜矢量化代码的许多其他R用户一样,for循环的使用引起了情绪上的痛苦。为了减轻痛苦,我使用一些常见的矢量化技术修改了我的代码。

ap <- function(Sigma, x) apply(x, 1, function(x) x %*% solve(Sigma, x))
lap <- function(Sigma, x) unlist(lapply(1:n, function(i) x[i, ] %*% solve(Sigma, x[i, ])))
loop <- function(Sigma, x){
  z  <- rep(0, n)
  for (i in 1:n) {
    z[i]  <- x[i, ] %*% solve(Sigma, x[i, ])
  }
  z
}

但速度比较显示没有太多收获。

library(microbenchmark)
microbenchmark(lap(Sigma, x), ap(Sigma, x), loop(Sigma, x))

# Unit: milliseconds
#           expr      min       lq     mean   median       uq       max neval
#  lap(Sigma, x) 4.207434 4.444895 5.092389 4.616912 5.283504  8.440802   100
#   ap(Sigma, x) 4.360204 4.523306 5.317304 4.685396 5.412771 10.168674   100
# loop(Sigma, x) 4.518645 4.679317 6.204626 4.827831 5.438908 94.115144   100

是否还有改进的余地,或者我应该去Rcpp让自己摆脱使用for loops的罪恶?

3 个答案:

答案 0 :(得分:1)

如果您将x行存储在列表中并使用vapply而不是lapply,则可以将其加速,如下所示

# First, make a list of the rows of x
xl <- vector("list",nrow(x))
for (i in seq_along(xl)) xl[[i]] <- x[i, ] 

# Apply solve
solve.mat <- vapply(xl, solve, numeric(2), a = Sigma)
# Take the dot product of each pair of elements
result <- colSums(solve.mat * t(x))
all(result == lap(Sigma, x))
# [1] TRUE

一步编写并比较

library(microbenchmark)
microbenchmark(lap = lap(Sigma, x),
    csums = colSums(vapply(xl, solve, numeric(2), a = Sigma) * t(x)))
# Unit: milliseconds
#   expr      min       lq     mean   median       uq      max neval
#    lap 3.013343 3.050855 3.164558 3.097901 3.136355 4.206923   100
#  csums 2.224350 2.263772 2.354349 2.289751 2.317672 3.660294   100

答案 1 :(得分:0)

非常感谢您关注@konvas,@ beginneR。

我错误地计算x'S^{-1}x而@Ben Bolker在上面的链接中给出的答案计算x'SxcolSums(t(x) * (solve(Sigma) %*% t(x)))完全解决了我的问题。

我现在更加忠诚于矢量化代码。

microbenchmark(lap(Sigma, x), 
               product = diag(x %*% solve(Sigma) %*%t(x)), 
               colsum = colSums(t(x) * (solve(Sigma) %*% t(x)))
               )
#Unit: microseconds
#          expr      min        lq      mean    median        uq      max neval
# lap(Sigma, x) 4283.616 4384.9215 4961.9761 4475.6920 4700.3885 9472.096   100
#       product  126.835  134.3315  165.3030  144.0575  199.5725  306.349   100
#        colsum   92.391  102.9265  130.0202  110.8285  146.2855  346.061   100

答案 2 :(得分:0)

以下在您的示例中并非如此,但在许多其他情况下可能如此。假设你的矩阵是正定的。例如。让Sigma <- matrix(c(1, 1, 1, 4), ncol = 2)。然后,您可以在计算 Cholesky 分解后反向求解,这在数值上是一个更好的主意,在这种情况下也更快:

lap <- function(Sigma, x) unlist(lapply(1:n, function(i) x[i, ] %*% solve(Sigma, x[i, ])))

bench::mark(lap = lap(Sigma, x),
            product = diag(x %*% solve(Sigma) %*%t(x)),
            `product smarter` = diag(x %*% tcrossprod(solve(Sigma), x)),
            colsum = colSums(t(x) * (solve(Sigma) %*% t(x))),
            `colsum (smarter)` = rowSums(x %*% solve(Sigma) * x),
            backsolve = {
              ch <- chol(Sigma)
              colSums(backsolve(ch, t(x), transpose = TRUE)^2)
            })
#R> # A tibble: 6 x 13
#R>   expression            min  median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result  memory     time     gc       
#R>   <bch:expr>       <bch:tm> <bch:t>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm> <list>  <list>     <list>   <list>   
#R> 1 lap               800.1µs   878µs     1113.    1.66KB     12.7   527     6      474ms <dbl [… <Rprofmem… <bch:tm… <tibble …
#R> 2 product            30.4µs    32µs    29895.   83.92KB     53.9  9982    18      334ms <dbl [… <Rprofmem… <bch:tm… <tibble …
#R> 3 product smarter    28.2µs    30µs    32300.   82.31KB     61.5  9981    19      309ms <dbl [… <Rprofmem… <bch:tm… <tibble …
#R> 4 colsum             20.5µs  22.3µs    44005.    5.66KB     13.2  9997     3      227ms <dbl [… <Rprofmem… <bch:tm… <tibble …
#R> 5 colsum (smarter)   15.1µs  16.2µs    60101.    4.05KB     12.0  9998     2      166ms <dbl [… <Rprofmem… <bch:tm… <tibble …
#R> 6 backsolve            15µs  16.3µs    59162.    5.66KB     11.8  9998     2      169ms <dbl [… <Rprofmem… <bch:tm… <tibble …

还要注意 colSums 版本的更快版本发布了 in this answer,其中避免了转置调用。