为什么天真的C ++矩阵乘法比BLAS慢100倍?

时间:2013-01-25 23:52:37

标签: c++ linux matlab c++11 matrix-multiplication

我正在研究大型矩阵乘法并运行以下实验以形成基线测试:

  1. 从std normal(0 mean,1 stddev)随机生成两个4096x4096矩阵X,Y。
  2. Z = X * Y
  3. 对Z的元素求和(以确保它们被访问)并输出。
  4. 这是天真的C ++实现:

    #include <iostream>
    #include <algorithm>
    
    using namespace std;
    
    int main()
    {
        constexpr size_t dim = 4096;
    
        float* x = new float[dim*dim];
        float* y = new float[dim*dim];
        float* z = new float[dim*dim];
    
        random_device rd;
        mt19937 gen(rd());
        normal_distribution<float> dist(0, 1);
    
        for (size_t i = 0; i < dim*dim; i++)
        {
            x[i] = dist(gen);
            y[i] = dist(gen);
        }
    
        for (size_t row = 0; row < dim; row++)
            for (size_t col = 0; col < dim; col++)
            {
                float acc = 0;
    
                for (size_t k = 0; k < dim; k++)
                    acc += x[row*dim + k] * y[k*dim + col];
    
                z[row*dim + col] = acc;
            }
    
        float t = 0;
    
        for (size_t i = 0; i < dim*dim; i++)
            t += z[i];
    
        cout << t << endl;
    
        delete x;
        delete y;
        delete z;
    }
    

    编译并运行:

    $ g++ -std=gnu++11 -O3 test.cpp
    $ time ./a.out
    

    这是Octave / matlab实现:

    X = stdnormal_rnd(4096, 4096);
    Y = stdnormal_rnd(4096, 4096);
    Z = X*Y;
    sum(sum(Z))
    

    执行命令

    $ time octave < test.octave
    

    引擎盖下的Octave正在使用BLAS(我假设sgemm函数)

    Linux x86-64上的硬件是i7 3930X,内存为24 GB。 BLAS似乎使用两个核心。也许是一个超线程对?

    我发现在-O3上使用GCC 4.7编译的C ++版本需要9分钟才能执行:

    real    9m2.126s
    user    9m0.302s
    sys         0m0.052s
    

    八度音阶版花了6秒钟:

    real    0m5.985s
    user    0m10.881s
    sys         0m0.144s
    

    我明白BLAS是针对所有地狱而优化的,天真的算法完全忽略了缓存等等,但严重的是 - 90次?

    任何人都能解释这种差异吗? BLAS实现的架构究竟是什么?我看到它正在使用Fortran,但在CPU级别发生了什么?它使用什么算法?它是如何使用CPU缓存的?它调用了哪些x86-64机器指令? (它是否使用像AVX这样的高级CPU功能?)它从哪里获得了额外的速度?

    对C ++算法的哪些关键优化可以使其与BLAS版本相提并论?

    我在gdb下运行八度音,并在计算中途停了几次。它已经启动了第二个线程,这里是堆栈(所有站点看起来都相似):

    (gdb) thread 1
    #0  0x00007ffff6e17148 in pthread_join () from /lib/x86_64-linux-gnu/libpthread.so.0
    #1  0x00007ffff1626721 in ATL_join_tree () from /usr/lib/libblas.so.3
    #2  0x00007ffff1626702 in ATL_join_tree () from /usr/lib/libblas.so.3
    #3  0x00007ffff15ae357 in ATL_dptgemm () from /usr/lib/libblas.so.3
    #4  0x00007ffff1384b59 in atl_f77wrap_dgemm_ () from /usr/lib/libblas.so.3
    #5  0x00007ffff193effa in dgemm_ () from /usr/lib/libblas.so.3
    #6  0x00007ffff6049727 in xgemm(Matrix const&, Matrix const&, blas_trans_type, blas_trans_type) () from /usr/lib/x86_64-linux-gnu/liboctave.so.1
    #7  0x00007ffff6049954 in operator*(Matrix const&, Matrix const&) () from /usr/lib/x86_64-linux-gnu/liboctave.so.1
    #8  0x00007ffff7839e4e in ?? () from /usr/lib/x86_64-linux-gnu/liboctinterp.so.1
    #9  0x00007ffff765a93a in do_binary_op(octave_value::binary_op, octave_value const&, octave_value const&) () from /usr/lib/x86_64-linux-gnu/liboctinterp.so.1
    #10 0x00007ffff76c4190 in tree_binary_expression::rvalue1(int) () from /usr/lib/x86_64-linux-gnu/liboctinterp.so.1
    #11 0x00007ffff76c33a5 in tree_simple_assignment::rvalue1(int) () from /usr/lib/x86_64-linux-gnu/liboctinterp.so.1
    #12 0x00007ffff76d0864 in tree_evaluator::visit_statement(tree_statement&) () from /usr/lib/x86_64-linux-gnu/liboctinterp.so.1
    #13 0x00007ffff76cffae in tree_evaluator::visit_statement_list(tree_statement_list&) () from /usr/lib/x86_64-linux-gnu/liboctinterp.so.1
    #14 0x00007ffff757f6d4 in main_loop() () from /usr/lib/x86_64-linux-gnu/liboctinterp.so.1
    #15 0x00007ffff7527abf in octave_main () from /usr/lib/x86_64-linux-gnu/liboctinterp.so.1
    
    (gdb) thread 2
    #0  0x00007ffff14ba4df in ATL_dJIK56x56x56TN56x56x0_a1_b1 () from /usr/lib/libblas.so.3
    (gdb) bt
    #0  0x00007ffff14ba4df in ATL_dJIK56x56x56TN56x56x0_a1_b1 () from /usr/lib/libblas.so.3
    #1  0x00007ffff15a5fd7 in ATL_dmmIJK2 () from /usr/lib/libblas.so.3
    #2  0x00007ffff15a6ae4 in ATL_dmmIJK () from /usr/lib/libblas.so.3
    #3  0x00007ffff1518e65 in ATL_dgemm () from /usr/lib/libblas.so.3
    #4  0x00007ffff15adf7a in ATL_dptgemm0 () from /usr/lib/libblas.so.3
    #5  0x00007ffff6e15e9a in start_thread () from /lib/x86_64-linux-gnu/libpthread.so.0
    #6  0x00007ffff6b41cbd in clone () from /lib/x86_64-linux-gnu/libc.so.6
    #7  0x0000000000000000 in ?? ()
    

    按预期调用BLAS gemm

    第一个线程似乎是加入了第二个线程,所以我不确定这两个线程是否考虑了200%的CPU使用率。

    哪个库是ATL_dgemm libblas.so.3,它的代码在哪里?

    $ ls -al /usr/lib/libblas.so.3
    /usr/lib/libblas.so.3 -> /etc/alternatives/libblas.so.3
    
    $ ls -al /etc/alternatives/libblas.so.3
    /etc/alternatives/libblas.so.3 -> /usr/lib/atlas-base/atlas/libblas.so.3
    
    $ ls -al /usr/lib/atlas-base/atlas/libblas.so.3
    /usr/lib/atlas-base/atlas/libblas.so.3 -> libblas.so.3.0
    
    $ ls -al /usr/lib/atlas-base/atlas/libblas.so.3.0
    /usr/lib/atlas-base/atlas/libblas.so.3.0
    
    $ dpkg -S /usr/lib/atlas-base/atlas/libblas.so.3.0
    libatlas3-base: /usr/lib/atlas-base/atlas/libblas.so.3.0
    
    $ apt-get source libatlas3-base
    

    这是ATLAS 3.8.4

    以下是我后来实施的优化:

    使用平铺方法,我将64x64 X,Y和Z块预加载到单独的数组中。

    更改每个块的计算,以便内部循环如下所示:

    for (size_t tcol = 0; tcol < block_width; tcol++)
        bufz[trow][tcol] += B * bufy[tk][tcol];
    

    这允许GCC自动向量化到SIMD指令,并允许指令级并行(我认为)。

    开启march=corei7-avx。这增加了30%的额外速度,但是作弊因为我认为BLAS库是预先构建的。

    以下是代码:

    #include <iostream>
    #include <algorithm>
    
    using namespace std;
    
    constexpr size_t dim = 4096;
    constexpr size_t block_width = 64;
    constexpr size_t num_blocks = dim / block_width;
    
    double X[dim][dim], Y[dim][dim], Z[dim][dim];
    
    double bufx[block_width][block_width];
    double bufy[block_width][block_width];
    double bufz[block_width][block_width];
    
    void calc_block()
    {
        for (size_t trow = 0; trow < block_width; trow++)
            for (size_t tk = 0; tk < block_width; tk++)
            {
                double B = bufx[trow][tk];
    
                for (size_t tcol = 0; tcol < block_width; tcol++)
                    bufz[trow][tcol] += B * bufy[tk][tcol];
            }
    }
    
    int main()
    {
        random_device rd;
        mt19937 gen(rd());
        normal_distribution<double> dist(0, 1);
    
        for (size_t row = 0; row < dim; row++)
            for (size_t col = 0; col < dim; col++)
            {
                X[row][col] = dist(gen);
                Y[row][col] = dist(gen);
                Z[row][col] = 0;
            }
    
        for (size_t block_row = 0; block_row < num_blocks; block_row++)
            for (size_t block_col = 0; block_col < num_blocks; block_col++)
            {
                for (size_t trow = 0; trow < block_width; trow++)
                    for (size_t tcol = 0; tcol < block_width; tcol++)
                        bufz[trow][tcol] = 0;
    
                for (size_t block_k = 0; block_k < num_blocks; block_k++)
                {
                    for (size_t trow = 0; trow < block_width; trow++)
                        for (size_t tcol = 0; tcol < block_width; tcol++)
                        {
                            bufx[trow][tcol] = X[block_row*block_width + trow][block_k*block_width + tcol];
                            bufy[trow][tcol] = Y[block_k*block_width + trow][block_col*block_width + tcol];
                        }
    
                    calc_block();
                }
    
                for (size_t trow = 0; trow < block_width; trow++)
                    for (size_t tcol = 0; tcol < block_width; tcol++)
                        Z[block_row*block_width + trow][block_col*block_width + tcol] = bufz[trow][tcol];
    
            }
    
        double t = 0;
    
        for (size_t row = 0; row < dim; row++)
            for (size_t col = 0; col < dim; col++)
                t += Z[row][col];
    
        cout << t << endl;
    }
    

    所有操作都在calc_block函数中 - 超过90%的时间花在其中。

    新时间是:

    real    0m17.370s
    user    0m17.213s
    sys 0m0.092s
    

    更接近。

    calc_block函数的反编译如下:

    0000000000401460 <_Z10calc_blockv>:
      401460:   b8 e0 21 60 00          mov    $0x6021e0,%eax
      401465:   41 b8 e0 23 61 00       mov    $0x6123e0,%r8d
      40146b:   31 ff                   xor    %edi,%edi
      40146d:   49 29 c0                sub    %rax,%r8
      401470:   49 8d 34 00             lea    (%r8,%rax,1),%rsi
      401474:   48 89 f9                mov    %rdi,%rcx
      401477:   ba e0 a1 60 00          mov    $0x60a1e0,%edx
      40147c:   48 c1 e1 09             shl    $0x9,%rcx
      401480:   48 81 c1 e0 21 61 00    add    $0x6121e0,%rcx
      401487:   66 0f 1f 84 00 00 00    nopw   0x0(%rax,%rax,1)
      40148e:   00 00 
      401490:   c4 e2 7d 19 01          vbroadcastsd (%rcx),%ymm0
      401495:   48 83 c1 08             add    $0x8,%rcx
      401499:   c5 fd 59 0a             vmulpd (%rdx),%ymm0,%ymm1
      40149d:   c5 f5 58 08             vaddpd (%rax),%ymm1,%ymm1
      4014a1:   c5 fd 29 08             vmovapd %ymm1,(%rax)
      4014a5:   c5 fd 59 4a 20          vmulpd 0x20(%rdx),%ymm0,%ymm1
      4014aa:   c5 f5 58 48 20          vaddpd 0x20(%rax),%ymm1,%ymm1
      4014af:   c5 fd 29 48 20          vmovapd %ymm1,0x20(%rax)
      4014b4:   c5 fd 59 4a 40          vmulpd 0x40(%rdx),%ymm0,%ymm1
      4014b9:   c5 f5 58 48 40          vaddpd 0x40(%rax),%ymm1,%ymm1
      4014be:   c5 fd 29 48 40          vmovapd %ymm1,0x40(%rax)
      4014c3:   c5 fd 59 4a 60          vmulpd 0x60(%rdx),%ymm0,%ymm1
      4014c8:   c5 f5 58 48 60          vaddpd 0x60(%rax),%ymm1,%ymm1
      4014cd:   c5 fd 29 48 60          vmovapd %ymm1,0x60(%rax)
      4014d2:   c5 fd 59 8a 80 00 00    vmulpd 0x80(%rdx),%ymm0,%ymm1
      4014d9:   00 
      4014da:   c5 f5 58 88 80 00 00    vaddpd 0x80(%rax),%ymm1,%ymm1
      4014e1:   00 
      4014e2:   c5 fd 29 88 80 00 00    vmovapd %ymm1,0x80(%rax)
      4014e9:   00 
      4014ea:   c5 fd 59 8a a0 00 00    vmulpd 0xa0(%rdx),%ymm0,%ymm1
      4014f1:   00 
      4014f2:   c5 f5 58 88 a0 00 00    vaddpd 0xa0(%rax),%ymm1,%ymm1
      4014f9:   00 
      4014fa:   c5 fd 29 88 a0 00 00    vmovapd %ymm1,0xa0(%rax)
      401501:   00 
      401502:   c5 fd 59 8a c0 00 00    vmulpd 0xc0(%rdx),%ymm0,%ymm1
      401509:   00 
      40150a:   c5 f5 58 88 c0 00 00    vaddpd 0xc0(%rax),%ymm1,%ymm1
      401511:   00 
      401512:   c5 fd 29 88 c0 00 00    vmovapd %ymm1,0xc0(%rax)
      401519:   00 
      40151a:   c5 fd 59 8a e0 00 00    vmulpd 0xe0(%rdx),%ymm0,%ymm1
      401521:   00 
      401522:   c5 f5 58 88 e0 00 00    vaddpd 0xe0(%rax),%ymm1,%ymm1
      401529:   00 
      40152a:   c5 fd 29 88 e0 00 00    vmovapd %ymm1,0xe0(%rax)
      401531:   00 
      401532:   c5 fd 59 8a 00 01 00    vmulpd 0x100(%rdx),%ymm0,%ymm1
      401539:   00 
      40153a:   c5 f5 58 88 00 01 00    vaddpd 0x100(%rax),%ymm1,%ymm1
      401541:   00 
      401542:   c5 fd 29 88 00 01 00    vmovapd %ymm1,0x100(%rax)
      401549:   00 
      40154a:   c5 fd 59 8a 20 01 00    vmulpd 0x120(%rdx),%ymm0,%ymm1
      401551:   00 
      401552:   c5 f5 58 88 20 01 00    vaddpd 0x120(%rax),%ymm1,%ymm1
      401559:   00 
      40155a:   c5 fd 29 88 20 01 00    vmovapd %ymm1,0x120(%rax)
      401561:   00 
      401562:   c5 fd 59 8a 40 01 00    vmulpd 0x140(%rdx),%ymm0,%ymm1
      401569:   00 
      40156a:   c5 f5 58 88 40 01 00    vaddpd 0x140(%rax),%ymm1,%ymm1
      401571:   00 
      401572:   c5 fd 29 88 40 01 00    vmovapd %ymm1,0x140(%rax)
      401579:   00 
      40157a:   c5 fd 59 8a 60 01 00    vmulpd 0x160(%rdx),%ymm0,%ymm1
      401581:   00 
      401582:   c5 f5 58 88 60 01 00    vaddpd 0x160(%rax),%ymm1,%ymm1
      401589:   00 
      40158a:   c5 fd 29 88 60 01 00    vmovapd %ymm1,0x160(%rax)
      401591:   00 
      401592:   c5 fd 59 8a 80 01 00    vmulpd 0x180(%rdx),%ymm0,%ymm1
      401599:   00 
      40159a:   c5 f5 58 88 80 01 00    vaddpd 0x180(%rax),%ymm1,%ymm1
      4015a1:   00 
      4015a2:   c5 fd 29 88 80 01 00    vmovapd %ymm1,0x180(%rax)
      4015a9:   00 
      4015aa:   c5 fd 59 8a a0 01 00    vmulpd 0x1a0(%rdx),%ymm0,%ymm1
      4015b1:   00 
      4015b2:   c5 f5 58 88 a0 01 00    vaddpd 0x1a0(%rax),%ymm1,%ymm1
      4015b9:   00 
      4015ba:   c5 fd 29 88 a0 01 00    vmovapd %ymm1,0x1a0(%rax)
      4015c1:   00 
      4015c2:   c5 fd 59 8a c0 01 00    vmulpd 0x1c0(%rdx),%ymm0,%ymm1
      4015c9:   00 
      4015ca:   c5 f5 58 88 c0 01 00    vaddpd 0x1c0(%rax),%ymm1,%ymm1
      4015d1:   00 
      4015d2:   c5 fd 29 88 c0 01 00    vmovapd %ymm1,0x1c0(%rax)
      4015d9:   00 
      4015da:   c5 fd 59 82 e0 01 00    vmulpd 0x1e0(%rdx),%ymm0,%ymm0
      4015e1:   00 
      4015e2:   c5 fd 58 80 e0 01 00    vaddpd 0x1e0(%rax),%ymm0,%ymm0
      4015e9:   00 
      4015ea:   48 81 c2 00 02 00 00    add    $0x200,%rdx
      4015f1:   48 39 ce                cmp    %rcx,%rsi
      4015f4:   c5 fd 29 80 e0 01 00    vmovapd %ymm0,0x1e0(%rax)
      4015fb:   00 
      4015fc:   0f 85 8e fe ff ff       jne    401490 <_Z10calc_blockv+0x30>
      401602:   48 83 c7 01             add    $0x1,%rdi
      401606:   48 05 00 02 00 00       add    $0x200,%rax
      40160c:   48 83 ff 40             cmp    $0x40,%rdi
      401610:   0f 85 5a fe ff ff       jne    401470 <_Z10calc_blockv+0x10>
      401616:   c5 f8 77                vzeroupper 
      401619:   c3                      retq   
      40161a:   66 0f 1f 44 00 00       nopw   0x0(%rax,%rax,1)
    

5 个答案:

答案 0 :(得分:18)

以下是导致代码与BLAS性能差异的三个因素(加上Strassen算法的注释)。

在内部循环中,在k上,您有y[k*dim + col]。由于内存缓存的排列方式,具有相同kdim的连续值col映射到同一缓存集。缓存的结构方式,每个内存地址都有一个缓存集,其中的内容必须在缓存中保存。每个缓存集都有几行(四个是典型的数字),每行都可以保存映射到该特定缓存集的任何内存地址。

因为你的内部循环以这种方式迭代y,所以每次使用y中的元素时,它必须将该元素的内存加载到与前一次迭代相同的集合中。这会强制驱逐集合中之前的一个缓存行。然后,在col循环的下一次迭代中,y的所有元素都已从缓存中逐出,因此必须重新加载它们。

因此,每个时间你的循环加载一个y元素,它必须从内存加载,这需要很多CPU周期。

高性能代码以两种方式避免这种情况。一,它将工作分成更小的块。行和列被分区为较小的大小,并使用较短的循环进行处理,这些循环能够使用高速缓存行中的所有元素,并在进入下一个块之前多次使用每个元素。因此,对x元素和y元素的大多数引用都来自缓存,通常在单个处理器周期中。第二,在某些情况下,代码会将数据从矩阵的列中复制(由于几何图形而将缓存打碎)到一行临时缓冲区(避免颠簸)。这再次允许从缓存而不是从内存提供大多数内存引用。

另一个因素是使用单指令多数据(SIMD)功能。许多现代处理器都有指令在一条指令中加载多个元素(四个float元素是典型的,但有些现在做八个),存储多个元素,添加多个元素(例如,对于这四个元素中的每一个,将其添加到相应的四个中的一个),乘以多个元素,依此类推。如果您能够安排工作以使用这些说明,只需使用此类说明即可使您的代码快四倍。

这些指令不能直接在标准C中访问。一些优化器现在尽可能尝试使用这些指令,但这种优化很困难,并且从中获得很多好处并不常见。许多编译器提供了可以访问这些指令的语言的扩展。就个人而言,我通常更喜欢使用汇编语言来编写SIMD。

另一个因素是在处理器上使用指令级并行执行功能。请注意,在您的内循环中,acc已更新。在上一次迭代完成更新acc之前,下一次迭代无法添加到acc。相反,高性能代码将保持多个并行运行(甚至多个SIMD总和)。这样做的结果是,当一个和的加法正在执行时,将开始添加另一个和。在今天的处理器上,通常一次支持四个或更多浮点运算。如上所述,您的代码根本无法执行此操作。有些编译器会尝试通过重新排列循环来优化代码,但这需要编译器能够看到特定循环的迭代彼此独立,或者可以用另一个循环等进行转换。

使用缓存有效地提供十倍的性能提升是非常可行的,SIMD提供另外四个,而指令级并行提供了另外四个,共提供160个。

以下是基于this Wikipedia page的Strassen算法效果的粗略估计。维基百科页面说Strassen略好于n = 100附近的直接乘法。这表明执行时间的常数因子的比率是100 3 / 100 2.807 ≈2.4 。显然,这将根据处理器型号,与缓存效果交互的矩阵大小等而有很大差异。然而,简单的外推表明,Strassen在n = 4096((4096/100) 3-2.807 ≈2.05)时的直接乘法大约是两倍。再次,这只是一个大概的估计。

对于后来的优化,请考虑内部循环中的此代码:

bufz[trow][tcol] += B * bufy[tk][tcol];

这方面的一个潜在问题是,bufz通常会重叠bufy。由于您使用bufzbufy的全局定义,因此编译器可能知道它们在这种情况下不会重叠。但是,如果将此代码移动到作为参数传递bufzbufy的子例程中,特别是如果在单独的源文件中编译该子例程,则编译器不太可能知道{ {1}}和bufz不重叠。在这种情况下,编译器无法对代码进行向量化或重新排序,因为此迭代中的bufy可能与另一次迭代中的bufz[trow][tcol]相同。

即使编译器可以看到在当前源模块中使用不同的bufy[tk][tcol]bufz调用子例程,如果例程具有bufy链接(默认值),则编译器必须允许从外部模块调用例程,因此如果externbufz重叠,它必须生成能够正常工作的代码。 (编译器可以处理的一种方法是生成两个版本的例程,一个从外部模块调用,一个从当前正在编译的模块调用。是否这样做取决于你的编译器,优化开关,等等。)如果您将例程声明为bufy,则编译器知道无法从外部模块调用它(除非您获取其地址并且有可能将地址传递到当前模块之外)。

另一个潜在的问题是,即使编译器对此代码进行矢量化,它也不一定为您执行的处理器生成最佳代码。查看生成的汇编代码,看起来编译器只重复使用static。一遍又一遍,它将内存中的值乘以%ymm1,将内存中的值添加到%ymm1,并将%ymm1中的值存储到内存中。这有两个问题。

一,你不希望经常将这些部分和存储到内存中。您希望在寄存器中累积许多新增内容,并且寄存器将很少写入内存。说服编译器执行此操作可能需要重写代码以明确保留临时对象中的部分和,并在循环完成后将它们写入内存。

其二,这些说明名义上是连续依赖的。在乘法完成之前,添加无法启动,并且在添加完成之前,存储无法写入内存。 Core i7具有强大的乱序执行功能。因此,虽然它有等待开始执行的添加,但它会在指令流中稍后查看乘法并启动它。 (即使该乘法也使用%ymm1,处理器会动态重新映射寄存器,以便它使用不同的内部寄存器来实现此乘法。)即使您的代码充满了连续的依赖关系,处理器也会尝试一次执行几条指令。但是,有很多事情可能会干扰这一点。您可以用完处理器用于重命名的内部寄存器。您使用的内存地址可能会遇到错误的冲突。 (处理器查看十几个内存地址的低位,以查看该地址是否与它尝试从早期指令加载或存储的另一个地址相同。如果这些位相等,则处理器具有延迟当前的负载或存储,直到它可以验证整个地址是不同的。这种延迟可以比当前的负载或存储更多地充电。)因此,最好是有完全独立的指令。

这是我更喜欢在汇编中编写高性能代码的另一个原因。要在C中执行此操作,您必须说服编译器通过执行诸如编写一些您自己的SIMD代码(使用它们的语言扩展)和手动展开循环(写出多次迭代)等方式为您提供这样的指令。

复制进出缓冲区时,可能会出现类似的问题。但是,您报告90%的时间用于%ymm1,所以我没有仔细研究过这个问题。

此外,Strassen的算法是否解释了剩下的任何差异?

答案 1 :(得分:5)

Strassen算法比朴素算法有两个优点:

  1. 操作次数方面的时间复杂度更高,正如其他答案正确指出的那样;
  2. 这是cache-oblivious algorithmThe difference in number of cache misses的顺序为B*M½,其中B是缓存行大小,M是缓存大小。
  3. 我认为第二点对于您正在经历的放缓有很大影响。如果您在Linux下运行您的应用程序,我建议您使用perf工具运行它们,该工具会告诉您程序遇到的缓存未命中数。

答案 2 :(得分:2)

我不知道信息有多可靠,但是Wikipedia说BLAS使用Strassen的算法来处理大矩阵。你的确很重要。那是O(n ^ 2.807),这比你的O(n ^ 3)天真的算法好。

答案 3 :(得分:1)

这是一个非常复杂的话题,Eric在上面的帖子中得到了很好的回答。我只是想在这个方向指出一个有用的参考,第84页:

http://www.rrze.fau.de/dienste/arbeiten-rechnen/hpc/HPC4SE/

建议在阻止之上进行“循环展开和阻塞”。

  

任何人都能解释这种差异吗?

一般说明是,操作次数/数据数的比率是O(N ^ 3)/ O(N ^ 2)。因此,矩阵 - 矩阵乘法是一种缓存绑定算法,这意味着对于大矩阵大小,您不会遇到常见的内存带宽瓶颈。 如果代码经过优化,您可以获得高达90%的CPU峰值性能。因此,埃里克详细阐述的优化潜力是巨大的,正如您所观察到的那样。实际上,看到性能最佳的代码并用另一个编译器编译你的最终程序会非常有趣(英特尔通常吹嘘是最好的)。

答案 4 :(得分:-1)

大约一半的差异在算法改进中得到了解决。 (4096 * 4096)^ 3是算法的复杂度,或4.7x10 ^ 21,(4096 * 4096)^ 2.807是1x10 ^ 20。这是大约47倍的差异。

其他2x将通过更智能地使用缓存,SSE指令和其他此类低级优化来解决。

编辑:我说谎,n是宽度,而不是宽度^ 2。该算法实际上只占大约4倍,因此还有大约22倍。线程,缓存和SSE指令可能会解释这些问题。