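Does MATLAB optimize diag(A*B)?

Time: 2014-05-15 14:34:12

Tags: matlab matrix linear-algebra

Suppose I have two very large matrices, A (M-by-N) and B (N-by-M). I need the diagonal of A*B. Computing the full product A*B requires M * M * N multiplications, whereas computing only its diagonal requires M * N multiplications, because the elements that end up off the diagonal never need to be computed.

Does MATLAB realize this automatically and optimize diag(A*B) on the fly, or am I better off using a for loop in this case?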
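
The operation counts above follow from the fact that the i-th diagonal entry of A*B is the dot product of row i of A with column i of B. A minimal sketch on small random matrices, where the sizes and the names d_full, d_rows, and max_diff are illustrative assumptions rather than anything from the question:

```matlab
% Illustrative sizes only; the question concerns much larger matrices.
M = 4; N = 3;
A = rand(M, N);
B = rand(N, M);

% Full product, then extract the diagonal: M*M*N scalar multiplications.
d_full = diag(A*B);

% Row-by-column dot products only: M*N scalar multiplications.
d_rows = zeros(M, 1);
for i = 1:M
    d_rows(i) = A(i,:) * B(:,i);
end

% The two results should agree up to floating-point rounding.
max_diff = max(abs(d_full - d_rows));
fprintf('full: %d mults, diagonal only: %d mults, max diff: %g\n', ...
        M*M*N, M*N, max_diff);
```

This only illustrates the arithmetic; it says nothing about which form MATLAB actually runs faster.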

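4 answers:

Answer 0 (score: 11)

One can also implement diag(A*B) as sum(A.*B',2). Let's benchmark this along with all the other implementations/solutions suggested for this question.

The different approaches, implemented as functions for benchmarking, are listed below:

  1. Sum-multiplication method-1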

    function out = sum_mult_method1(A,B)
    
    out = sum(A.*B',2);
    
  2. Sum-multiplication method-2

    function out = sum_mult_method2(A,B)
    
    out = sum(A.'.*B).';
    
  3. For-loop method

    function out = for_loop_method(A,B)
    
    M = size(A,1);
    out = zeros(M,1);
    for i=1:M
        out(i) = A(i,:) * B(:,i);
    end
    
  4. Full/direct multiplication method

    function out = direct_mult_method(A,B)
    
    out = diag(A*B);
    
  5. Bsxfun method

    function out = bsxfun_method(A,B)
    
    out = sum(bsxfun(@times,A,B.'),2);
    
  6. Benchmarking code

    num_runs = 1000;
    M_arr = [100 200 500 1000];
    N = 4;
    
    %// Warm up tic/toc.
    tic();
    elapsed = toc();
    tic();
    elapsed = toc();
    
    for k2 = 1:numel(M_arr)
        M = M_arr(k2);
    
        fprintf('\n')
        disp(strcat('*** Benchmarking sizes are M =',num2str(M),' and N = ',num2str(N)));
    
        A = randi(9,M,N);
        B = randi(9,N,M);
    
        disp('1. Sum-multiplication method-1');
        tic
        for k = 1:num_runs
            out1 = sum_mult_method1(A,B);
        end
        toc
        clear out1
    
        disp('2. Sum-multiplication method-2');
        tic
        for k = 1:num_runs
            out2 = sum_mult_method2(A,B);
        end
        toc
        clear out2
    
        disp('3. For-loop method');
        tic
        for k = 1:num_runs
            out3 = for_loop_method(A,B);
        end
        toc
        clear out3
    
        disp('4. Direct-multiplication method');
        tic
        for k = 1:num_runs
            out4 = direct_mult_method(A,B);
        end
        toc
        clear out4
    
        disp('5. Bsxfun method');
        tic
        for k = 1:num_runs
            out5 = bsxfun_method(A,B);
        end
        toc
        clear out5
    
    end
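
Before reading the timings, it may be worth confirming that the approaches return the same vector. The following is a sketch only: it inlines the expressions from the functions above rather than calling them, and the size M = 50 and the names ref, out1 to out5, and all_match are assumptions made for illustration. Because randi(9,...) produces small integers, the results here should match exactly rather than merely to rounding.

```matlab
% Small integer-valued inputs, as in the benchmark, so every sum is exact.
M = 50; N = 4;
A = randi(9, M, N);
B = randi(9, N, M);

ref = diag(A*B);                            % direct multiplication as reference

out1 = sum(A.*B', 2);                       % sum-multiplication method-1
out2 = sum(A.'.*B).';                       % sum-multiplication method-2
out3 = zeros(M, 1);                         % for-loop method
for i = 1:M
    out3(i) = A(i,:) * B(:,i);
end
out5 = sum(bsxfun(@times, A, B.'), 2);      % bsxfun method

all_match = isequal(ref, out1) && isequal(ref, out2) && ...
            isequal(ref, out3) && isequal(ref, out5);
disp(all_match)
```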
    

    Results

    *** Benchmarking sizes are M =100 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 0.015242 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 0.015180 seconds.
    3. For-loop method
    Elapsed time is 0.192021 seconds.
    4. Direct-multiplication method
    Elapsed time is 0.065543 seconds.
    5. Bsxfun method
    Elapsed time is 0.054149 seconds.
    
    *** Benchmarking sizes are M =200 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 0.009138 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 0.009428 seconds.
    3. For-loop method
    Elapsed time is 0.435735 seconds.
    4. Direct-multiplication method
    Elapsed time is 0.148908 seconds.
    5. Bsxfun method
    Elapsed time is 0.030946 seconds.
    
    *** Benchmarking sizes are M =500 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 0.033287 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 0.026405 seconds.
    3. For-loop method
    Elapsed time is 0.965260 seconds.
    4. Direct-multiplication method
    Elapsed time is 2.832855 seconds.
    5. Bsxfun method
    Elapsed time is 0.034923 seconds.
    
    *** Benchmarking sizes are M =1000 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 0.026068 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 0.032850 seconds.
    3. For-loop method
    Elapsed time is 1.775382 seconds.
    4. Direct-multiplication method
    Elapsed time is 13.764870 seconds.
    5. Bsxfun method
    Elapsed time is 0.044931 seconds.
    

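    Intermediate conclusions

    The sum-multiplication methods look like the best approaches, although the bsxfun method appears to catch up with them as M increases from 100 to 1000.

    Next, larger benchmark sizes were tested using only the sum-multiplication and bsxfun methods. The sizes were: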

    M_arr = [1000 2000 5000 10000 20000 50000];
    

    The results were:

    *** Benchmarking sizes are M =1000 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 0.030390 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 0.032334 seconds.
    5. Bsxfun method
    Elapsed time is 0.047377 seconds.
    
    *** Benchmarking sizes are M =2000 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 0.040111 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 0.045132 seconds.
    5. Bsxfun method
    Elapsed time is 0.060762 seconds.
    
    *** Benchmarking sizes are M =5000 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 0.099986 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 0.103213 seconds.
    5. Bsxfun method
    Elapsed time is 0.117650 seconds.
    
    *** Benchmarking sizes are M =10000 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 0.375604 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 0.273726 seconds.
    5. Bsxfun method
    Elapsed time is 0.226791 seconds.
    
    *** Benchmarking sizes are M =20000 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 1.906839 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 1.849166 seconds.
    5. Bsxfun method
    Elapsed time is 1.344905 seconds.
    
    *** Benchmarking sizes are M =50000 and N =4
    1. Sum-multiplication method-1
    Elapsed time is 5.159177 seconds.
    2. Sum-multiplication method-2
    Elapsed time is 5.081211 seconds.
    5. Bsxfun method
    Elapsed time is 3.866018 seconds.
    

    Alternative benchmarking code (with timeit)

    num_runs = 1000;
    M_arr = [1000 2000 5000 10000 20000 50000 100000 200000 500000 1000000];
    N = 4;
    
    timeall = zeros(5,numel(M_arr));
    for k2 = 1:numel(M_arr)
        M = M_arr(k2);
    
        A = rand(M,N);
        B = rand(N,M);
    
        f = @() sum_mult_method1(A,B);
        timeall(1,k2) = timeit(f);
        clear f
    
        f = @() sum_mult_method2(A,B);
        timeall(2,k2) = timeit(f);
        clear f
    
        f = @() bsxfun_method(A,B);
        timeall(5,k2) = timeit(f);
        clear f
    
    end
    
    figure,
    hold on
    plot(M_arr,timeall(1,:),'-ro')
    plot(M_arr,timeall(2,:),'-ko')
    plot(M_arr,timeall(5,:),'-.b')
    legend('sum-method1','sum-method2','bsxfun-method')
    xlabel('M ->')
    ylabel('Time(sec) ->')
    

    Plot

    [Figure: run time in seconds versus M for sum-method1, sum-method2, and bsxfun-method]

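    Final conclusions

    The sum-multiplication methods seem to do well up to a certain point, around the M = 5000 mark, after which bsxfun seems to have a slight upper hand.

    Future work

    One could vary N and study the performance of the implementations mentioned here.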
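
A sweep over N could be sketched along the following lines. This is an assumption-laden outline rather than a benchmark that was actually run for this answer: the fixed M = 2000, the values in N_arr, and the names t_sum and t_bsxfun are all hypothetical, and no timings are claimed.

```matlab
% Hypothetical sweep: M is fixed and N varies; both choices are illustrative.
M = 2000;
N_arr = [4 16 64];

t_sum    = zeros(1, numel(N_arr));
t_bsxfun = zeros(1, numel(N_arr));
for k = 1:numel(N_arr)
    N = N_arr(k);
    A = rand(M, N);
    B = rand(N, M);

    % timeit runs each handle several times and returns a typical time in seconds.
    t_sum(k)    = timeit(@() sum(A.*B.', 2));
    t_bsxfun(k) = timeit(@() sum(bsxfun(@times, A, B.'), 2));
end

% One column per N: first row N, then the two timings.
disp([N_arr; t_sum; t_bsxfun])
```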

Answer 1 (score: 4)

Yes, this is one of the rare cases where a for loop is better.

I ran the following script through the profiler:

M = 5000;
N = 5000;

A = rand(M, N); B = rand(N, M);
product = A*B;
diag1 = diag(product);

A = rand(M, N); B = rand(N, M);
diag2 = diag(A*B);

A = rand(M, N); B = rand(N, M);
diag3 = zeros(M,1);
for i=1:M
    diag3(i) = A(i,:) * B(:,i);
end

I reset A and B between each test in case MATLAB tries to speed anything up through caching.

Results (edited for brevity):

  time   calls  line
  6.29       1    5 product = A*B; 
< 0.01       1    6 diag1 = diag(product); 

  5.46       1    9 diag2 = diag(A*B); 

             1   12 diag3 = zeros(M,1); 
             1   13 for i=1:M 
  0.52    5000   14     diag3(i) = A(i,:) * B(:,i); 
< 0.01    5000   15 end 

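As we can see, the for-loop variant is an order of magnitude faster than the other two variants in this case. Although the diag(A*B) variant is actually faster than the diag(product) variant, the difference is marginal at best.

I tried a few different values of M and N, and in my tests the for-loop variant was slower only when M = 1.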

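Answer 2 (score: 3)

Actually, you can do this faster than a for loop using the wonders of bsxfun:

diag4 = sum(bsxfun(@times,A,B.'),2)

On my machine this is about twice as fast as the explicit for loop for large matrices (2,000-by-2,000 and larger), and it is faster for matrices larger than about 500-by-500.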

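Note that all of these methods produce numerically different results, because the summations and multiplications are carried out in a different order.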
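
The size of that discrepancy can be checked directly. A minimal sketch, assuming 500-by-500 random inputs and a fixed seed (both illustrative choices, as are the names d_direct, d_bsxfun, and rel_diff); the difference is expected to be on the order of floating-point rounding, though it need not be exactly zero:

```matlab
rng(0);                 % fixed seed so the run is repeatable
M = 500; N = 500;       % illustrative sizes
A = rand(M, N);
B = rand(N, M);

d_direct = diag(A*B);
d_bsxfun = sum(bsxfun(@times, A, B.'), 2);

% Largest relative difference between the two results.
rel_diff = max(abs(d_direct - d_bsxfun) ./ abs(d_direct));
fprintf('max relative difference: %g (eps = %g)\n', rel_diff, eps);
```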

Answer 3 (score: 3)

You can compute just the diagonal elements without a loop: simply use

sum(A.'.*B).'

or

sum(A.*B.',2)
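
One detail worth noting in these expressions is the use of .' (plain transpose) rather than ' (conjugate transpose). For real matrices the two coincide, but for complex inputs only the plain transpose reproduces diag(A*B). A small sketch with made-up 2-by-2 complex matrices (the values and the names ref, d_nonconj, and d_conj are illustrative):

```matlab
% Small complex-valued example; the values are arbitrary.
A = [1+2i, 3; 4, 5-1i];
B = [2, 1i; 1-1i, 3];

ref = diag(A*B);

d_nonconj = sum(A.*B.', 2);   % .' is the plain transpose
d_conj    = sum(A.*B', 2);    % ' is the conjugate transpose

err_nonconj = max(abs(ref - d_nonconj));   % matches the reference
err_conj    = max(abs(ref - d_conj));      % does not match for complex B
fprintf('error with .'': %g, error with '': %g\n', err_nonconj, err_conj);
```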