以有效方式对矩阵项的乘积求和

时间:2018-09-21 09:28:06

标签: r performance matrix sum

假设A是一个带有实项的n x n对称矩阵。我想计算A[u,t]*A[t,s]*A[s,u]s,t,u1n的总和。一种简单的方法如下。

n<-5
A<-matrix(sample(1:n^2),n)
A<-A%*%t(A)
isSymmetric(A)
S1<-0
for (s in 1:n)
    {
        for (t in 1:n)
            {
                for (u in 1:n)
                    {
                        S1<-S1+A[u,t]*A[t,s]*A[s,u]
                    }
            }
     }
print(S1)

但是,这是缓慢且效率低下的。我想出了以下更有效的代码。

S2<-0
for (s in 1:n)
    {
        S2<-S2+sum(t(A*A[,s])*A[,s])
    }
print(S2)
S1==S2

是否可以进一步改进此代码,以使我们不必完全使用循环?

1 个答案:

答案 0 :(得分:4)

尝试一下:

sum(A * A %*% t(A))

有关F.Prives注释,可以测试不同的方法:

set.seed(42)
n <- 10
A <- matrix(sample(1:n^2), n)
A <- A %*% t(A)
require(Matrix)
X <- forceSymmetric(A)

m1 <- sum(A * A %*% t(A))
m3 <- sum(X * X %*% t(X))

all.equal(m1, m3)
# [1] TRUE


bench::mark(sum(A * A %*% t(A)),
            sum(X * X %*% t(X)), check = F, relative = T)[, 1:10]
# # A tibble: 4 x 10
# expression                     min     mean   median      max `itr/sec` mem_alloc  n_gc n_itr total_time
# <chr>                     <bch:tm> <bch:tm> <bch:tm> <bch:tm>     <dbl> <bch:byt> <dbl> <int>   <bch:tm>
# 1 sum(A * A %*% t(A))           12us  17.26us  13.26us    334us    57929.    1.66KB     1  9999      173ms
# 3 sum(X * X %*% t(X))            1ms   1.43ms   1.16ms     41ms      701.    5.28KB     1   278      397ms

对于小型矩阵,基本矩阵看起来更快。

对于n <- 1000

# A tibble: 4 x 10
# expression                     min     mean   median      max `itr/sec` mem_alloc  n_gc n_itr total_time
# <chr>                     <bch:tm> <bch:tm> <bch:tm> <bch:tm>     <dbl> <bch:byt> <dbl> <int>   <bch:tm>
# 1 sum(A * A %*% t(A))          659ms    695ms    694ms    731ms      1.44    15.3MB     0     5      3.47s
# 3 sum(X * X %*% t(X))          708ms    749ms    759ms    774ms      1.34    45.8MB     0     5      3.74s

另外,基础速度要快一点。

p.s。

# A tibble: 6 x 10
  expression                     min     mean   median      max `itr/sec` mem_alloc  n_gc n_itr total_time
  <chr>                     <bch:tm> <bch:tm> <bch:tm> <bch:tm>     <dbl> <bch:byt> <dbl> <int>   <bch:tm>
1 sum(A * A %*% t(A))          673ms    769ms    714ms    894ms      1.30    15.3MB     0     5      3.84s
3 sum(X * X %*% t(X))          710ms    721ms    716ms    745ms      1.39    45.8MB     0     5       3.6s
5 sum(tcrossprod(A) * A)       399ms    407ms    403ms    418ms      2.46    15.3MB     0     5      2.03s
6 sum(tcrossprod(X) * X)       402ms    423ms    424ms    436ms      2.37    30.6MB     0     5      2.11s

sum(tcrossprod(A) * A)会更快,并给出相同的结果