Matlab中是否有一种有效的方法只计算3(或更多)矩阵乘积的对角线?具体来说我想要
diag(A'*B*A)
当A和B都非常大时,可能需要很长时间。如果只有两个矩阵
diag(B*A)
然后我可以这样快速地做到这一点:
sum(B.*A',2)
所以现在我用这样的3个矩阵计算对角线:
C = B*A;
ans = sum(A'.*C',2);
这有很大帮助,但第一次操作(C = B * A)仍然需要很长时间。整个事情也必须重复多次,导致我的代码运行数周。例如,B约为15k×15k,A约为32k×15k。没有什么是稀疏的。
答案 0 :(得分:2)
首先,欢迎!说实话,这似乎很难。稍微改变至少会略微提高速度:
N = 5000;
A = rand(N,N*2);
B = rand(N,N);
t = cputime;
diag(A'*B*A);
disp(['Elapsed cputime ' num2str(cputime-t)]);
t=cputime;
C = B*A;
sum(A'.*C',2);
disp(['Elapsed cputime ' num2str(cputime-t)]);
% slightly better...
t=cputime;
C = B*A;
sum(A.*C)';
disp(['Elapsed cputime ' num2str(cputime-t)]);
% slightly better than slightly better...
t=cputime;
sum(A.*(B*A))';
disp(['Elapsed cputime ' num2str(cputime-t)]);
结果:
Elapsed cputime 82.2593
Elapsed cputime 28.6106
Elapsed cputime 25.8338
Elapsed cputime 25.7714