矩阵乘法:Strassen vs. Standard

时间:2010-11-29 14:18:39

标签: c++ performance matrix matrix-multiplication strassen

我尝试用C ++实现Strassen algorithm矩阵乘法,但结果不是我所期望的。正如您所看到的,strassen总是花费更多时间,然后标准实现,并且只有2的幂的维度与标准实现一样快。什么地方出了错? alt text

matrix mult_strassen(matrix a, matrix b) {
if (a.dim() <= cut)
    return mult_std(a, b);

matrix a11 = get_part(0, 0, a);
matrix a12 = get_part(0, 1, a);
matrix a21 = get_part(1, 0, a);
matrix a22 = get_part(1, 1, a);

matrix b11 = get_part(0, 0, b);
matrix b12 = get_part(0, 1, b);
matrix b21 = get_part(1, 0, b);
matrix b22 = get_part(1, 1, b);

matrix m1 = mult_strassen(a11 + a22, b11 + b22); 
matrix m2 = mult_strassen(a21 + a22, b11);
matrix m3 = mult_strassen(a11, b12 - b22);
matrix m4 = mult_strassen(a22, b21 - b11);
matrix m5 = mult_strassen(a11 + a12, b22);
matrix m6 = mult_strassen(a21 - a11, b11 + b12);
matrix m7 = mult_strassen(a12 - a22, b21 + b22);

matrix c(a.dim(), false, true);
set_part(0, 0, &c, m1 + m4 - m5 + m7);
set_part(0, 1, &c, m3 + m5);
set_part(1, 0, &c, m2 + m4);
set_part(1, 1, &c, m1 - m2 + m3 + m6);

return c; 
}

<小时/> 计划
matrix.h http://pastebin.com/TYFYCTY7
matrix.cpp http://pastebin.com/wYADLJ8Y
main.cpp http://pastebin.com/48BSqGJr

g++ main.cpp matrix.cpp -o matrix -O3

5 个答案:

答案 0 :(得分:8)

一些想法:

  • 您是否优化过它以考虑用零填充非功率的两个大小的矩阵?我认为该算法假设您不打扰这些术语的倍增。这就是为什么你得到的运行时间在2 ^ n和2 ^(n + 1)-1之间的平坦区域。通过不将您知道的术语乘以零,您应该能够改进这些区域。或许Strassen只能用于2 ^ n大小的矩阵。
  • 考虑一个“大”矩阵是任意的,并且该算法仅略好于天真的情况,O(N ^ 3)对O(N ^ 2.8)。在尝试更大的矩阵之前,您可能看不到可衡量的收益。例如,我做了一些有限元建模,其中10,000x10,000矩阵被认为是“小”。很难从你的图表中看出,但看起来在Stassen案例中511案例可能会更快。
  • 尝试使用各种优化级别进行测试,包括根本不进行优化。
  • 这个算法似乎假设乘法比加法要贵得多。这在40年前首次开发时确实如此,但我相信更现代的处理器,加法和乘法之间的差异变小了。这可能会降低算法的有效性,这似乎会减少乘法次数,但会增加增加量。
  • 你有没有看过其他一些Strassen实现的想法?尝试对已知的良好实施进行基准测试,以确切了解您可以获得多快的速度。

答案 1 :(得分:2)

好的我不是这个领域的专家,但是这里可能存在其他问题而不是处理速度。首先,strassen方法使用更多堆栈并具有更多函数调用,这增加了内存移动。你的堆栈越大,你就会受到一定的惩罚,因为它需要从操作系统请求更大的帧。另外,您使用动态分配,这也是一个问题。

尝试使用固定大小(带模板参数)矩阵类?这至少可以解决分配问题。

注意:我不确定该事件是否适用于您的代码。您的矩阵类使用指针但没有复制构造函数或赋值运算符。你也在最后泄漏记忆,因为你没有析构函数......

答案 2 :(得分:2)

Strassen的大O是O(N ^ log 7),而O(N ^ 3)是常规的,即log 7 base 2,略小于3。

这是你需要进行的乘法次数。

它假设你没有任何其他成本,并且也应该“更快”,因为N足够大,而你可能没有。

你的大部分实现都是创建了很多子矩阵,而我的猜测是你存储它们的方式,你每次执行此操作都需要分配内存和复制。有一些“切片”矩阵和逻辑转置矩阵,如果你可以帮助你优化可能是你的过程中最慢的部分。

答案 3 :(得分:1)

我对Stassen倍增速度的速度感到震惊 实施是:

http://ezekiel.vancouver.wsu.edu/~cs330/lectures/linear_algebra/mm/mm.c

当n = 1024时,我的机器上的速度提高了近16倍。 我可以解释这个加速的唯一方法是 我的算法更适合缓存 - 即,它专注于小型 矩阵的各个部分因此数据更加本地化。

C ++实现的开销可能太高了 - 编译器生成的临时数比实际需要的更多。 我的实现试图通过重用内存来最小化这个 可能的。

答案 4 :(得分:-1)

远射,但你认为标准乘法可能会被编译器优化吗?你可以关闭优化吗​​?