Strassen算法不是最快的?

时间:2011-06-06 13:52:16

标签: algorithm strassen

我从某处复制了strassen的算法,然后执行了它。这是输出

n = 256
classical took 360ms
strassen 1 took 33609ms
strassen2 took 1172ms
classical took 437ms
strassen 1 took 32891ms
strassen2 took 1156ms
classical took 266ms
strassen 1 took 27234ms
strassen2 took 734ms

其中strassen1是动态分配内存的方法,strassen2是使用缓存的方法,classical是传统的矩阵乘法。这意味着我们古老而简单的经典算法是最好的。这是真的,还是我哪里弄错了?下面是Java代码。

import java.util.Random;

/**
 * Command-line driver that fills two random square int matrices and either
 * dumps the products (small sizes) or times the three multiplication
 * implementations (larger sizes).
 *
 * <p>Usage: {@code java TestIntMatrixMultiplication [n] [seed]}; both
 * arguments default to 256.
 */
class TestIntMatrixMultiplication {

    /** Number of nanoseconds in one millisecond, used when reporting timings. */
    private static final long NANOS_PER_MILLI = 1000000L;

    /**
     * Entry point.
     *
     * @param args optional matrix size {@code n} followed by an optional
     *             random seed
     * @throws NumberFormatException if an argument is not a valid integer
     */
    public static void main (String...args) throws Exception {
        final int n = args.length > 0 ? Integer.parseInt(args[0]) : 256;
        final int seed = args.length > 1 ? Integer.parseInt(args[1]) : 256;
        final Random random = new Random(seed);

        final int[][] a = new int[n][n];
        final int[][] b = new int[n][n];
        final int[][] c = new int[n][n];

        // a and b are filled in interleaved order so a given seed always
        // produces the same pair of matrices.
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                a[i][j] = random.nextInt(100);
                b[i][j] = random.nextInt(100);
            }
        }

        System.out.println("n = " + n);

        // Small inputs: print operands and both products for visual comparison
        // instead of timing anything.
        if (a.length < 64) {
            System.out.println("A");
            dumpMatrix(a);
            System.out.println("B");
            dumpMatrix(b);
            System.out.println("classic");
            Classical.mult(c, a, b);
            dumpMatrix(c);
            System.out.println("strassen");
            strassen2.mult(c, a, b);
            dumpMatrix(c);

            return;
        }

        // Three rounds, so the later rounds run after the JIT has had a
        // chance to compile the hot loops.
        for (int i = 0; i < 3; ++i) {
            timeMultiplies1(a, b, c);
            // The allocating variant is only timed up to n = 256.
            if (n <= 256) {
                timeMultiplies2(a, b, c);
            }
            timeMultiplies3(a, b, c);
        }
    }

    /**
     * Times one run of {@code Classical.mult} and prints the elapsed
     * milliseconds.
     */
    static void timeMultiplies1 (int[][] a, int[][] b, int[][] c) {
        // nanoTime is monotonic; currentTimeMillis can jump with clock changes.
        final long start = System.nanoTime();
        Classical.mult(c, a, b);
        final long elapsedMs = (System.nanoTime() - start) / NANOS_PER_MILLI;
        System.out.println("classical took " + elapsedMs + "ms");
    }

    /**
     * Times one run of {@code strassen1.mult} and prints the elapsed
     * milliseconds.
     */
    static void timeMultiplies2(int[][] a, int[][] b, int[][] c) {
        final long start = System.nanoTime();
        strassen1.mult(c, a, b);
        final long elapsedMs = (System.nanoTime() - start) / NANOS_PER_MILLI;
        System.out.println("strassen 1 took " + elapsedMs + "ms");
    }

    /**
     * Times one run of {@code strassen2.mult} and prints the elapsed
     * milliseconds.
     */
    static void timeMultiplies3 (int[][] a, int[][] b, int[][] c) {
        final long start = System.nanoTime();
        strassen2.mult(c, a, b);
        final long elapsedMs = (System.nanoTime() - start) / NANOS_PER_MILLI;
        System.out.println("strassen2 took " + elapsedMs + "ms");
    }

    /**
     * Prints a matrix to standard output, one bracketed, tab-separated row
     * per line.
     *
     * @param m the matrix to print
     */
    static void dumpMatrix (int[][] m) {
        for (int[] row : m) {
            System.out.print("[\t");
            for (int val : row) {
                System.out.print(val);
                System.out.print('\t');
            }
            System.out.println(']');
        }
    }
}

/**
 * Strassen matrix multiplication that allocates fresh sub-matrices at every
 * recursion level (the "dynamic" variant).
 *
 * <p>All matrices are square {@code int[n][n]} arrays. Arithmetic is plain
 * {@code int} arithmetic, so overflow wraps exactly as in the schoolbook
 * method.
 */
class strassen1{

    /** Returns the display name of this implementation. */
    public String getName () {
        return "Strassen(dynamic)";
    }

    /**
     * Multiplies {@code a} by {@code b} and stores the product in {@code c}.
     *
     * @param c destination matrix, at least as large as {@code a}; may be
     *          {@code null}, in which case a newly allocated product is returned
     * @param a left operand (square)
     * @param b right operand (same size as {@code a})
     * @return {@code c} filled with the product, or a new matrix when
     *         {@code c} is {@code null}
     */
    public static int[][] mult (int[][] c, int[][] a, int[][] b) {
        final int[][] product = strassenMatrixMultiplication(a, b);

        if (c == null) {
            return product;
        }

        // Copy into the caller's matrix so this method follows the same
        // "fill c" contract as the other implementations in this file.
        for (int i = 0; i < product.length; i++) {
            System.arraycopy(product[i], 0, c[i], 0, product[i].length);
        }

        return c;
    }

    /**
     * Recursively multiplies two square matrices with Strassen's seven-product
     * scheme.
     *
     * <p>Even sizes are split into four quadrants. A 1x1 matrix is multiplied
     * directly, an empty matrix yields an empty result, and an odd size
     * greater than one (which cannot be halved) is multiplied with the
     * schoolbook method.
     *
     * @param A left operand (square)
     * @param B right operand (same size as {@code A})
     * @return a newly allocated matrix holding {@code A * B}
     */
    public static int [][] strassenMatrixMultiplication(int [][] A, int [][] B) {
        int n = A.length;

        int [][] result = new int[n][n];

        if (n == 0) {
            // Nothing to multiply; recursing on n/2 == 0 would never terminate.
            return result;
        }

        if (n == 1) {
            result[0][0] = A[0][0] * B[0][0];
            return result;
        }

        if ((n & 1) != 0) {
            // An odd size cannot be cut into equal quadrants.
            multiplyClassically(A, B, result);
            return result;
        }

        final int half = n / 2;

        int [][] A11 = new int[half][half];
        int [][] A12 = new int[half][half];
        int [][] A21 = new int[half][half];
        int [][] A22 = new int[half][half];

        int [][] B11 = new int[half][half];
        int [][] B12 = new int[half][half];
        int [][] B21 = new int[half][half];
        int [][] B22 = new int[half][half];

        divideArray(A, A11, 0, 0);
        divideArray(A, A12, 0, half);
        divideArray(A, A21, half, 0);
        divideArray(A, A22, half, half);

        divideArray(B, B11, 0, 0);
        divideArray(B, B12, 0, half);
        divideArray(B, B21, half, 0);
        divideArray(B, B22, half, half);

        // The seven Strassen products.
        int [][] P1 = strassenMatrixMultiplication(addMatrices(A11, A22), addMatrices(B11, B22));
        int [][] P2 = strassenMatrixMultiplication(addMatrices(A21, A22), B11);
        int [][] P3 = strassenMatrixMultiplication(A11, subtractMatrices(B12, B22));
        int [][] P4 = strassenMatrixMultiplication(A22, subtractMatrices(B21, B11));
        int [][] P5 = strassenMatrixMultiplication(addMatrices(A11, A12), B22);
        int [][] P6 = strassenMatrixMultiplication(subtractMatrices(A21, A11), addMatrices(B11, B12));
        int [][] P7 = strassenMatrixMultiplication(subtractMatrices(A12, A22), addMatrices(B21, B22));

        // Recombine the products into the four result quadrants.
        int [][] C11 = addMatrices(subtractMatrices(addMatrices(P1, P4), P5), P7);
        int [][] C12 = addMatrices(P3, P5);
        int [][] C21 = addMatrices(P2, P4);
        int [][] C22 = addMatrices(subtractMatrices(addMatrices(P1, P3), P2), P6);

        copySubArray(C11, result, 0, 0);
        copySubArray(C12, result, 0, half);
        copySubArray(C21, result, half, 0);
        copySubArray(C22, result, half, half);

        return result;
    }

    /** Schoolbook product of two square matrices, written into {@code out}. */
    private static void multiplyClassically(int[][] A, int[][] B, int[][] out) {
        final int n = A.length;

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                int sum = 0;
                for (int k = 0; k < n; k++) {
                    sum += A[i][k] * B[k][j];
                }
                out[i][j] = sum;
            }
        }
    }

    /**
     * Returns the element-wise sum of two square matrices of equal size.
     *
     * @param A first operand
     * @param B second operand
     * @return a new matrix holding {@code A + B}
     */
    public static int [][] addMatrices(int [][] A, int [][] B) {
        int n = A.length;

        int [][] result = new int[n][n];

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                result[i][j] = A[i][j] + B[i][j];
            }
        }

        return result;
    }

    /**
     * Returns the element-wise difference of two square matrices of equal size.
     *
     * @param A minuend
     * @param B subtrahend
     * @return a new matrix holding {@code A - B}
     */
    public static int [][] subtractMatrices(int [][] A, int [][] B) {
        int n = A.length;

        int [][] result = new int[n][n];

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                result[i][j] = A[i][j] - B[i][j];
            }
        }

        return result;
    }

    /**
     * Copies the square block of {@code parent} whose top-left corner is
     * {@code (iB, jB)} into {@code child}; the block size is
     * {@code child.length}.
     *
     * @param parent source matrix
     * @param child  destination block (square)
     * @param iB     first source row
     * @param jB     first source column
     */
    public static void divideArray(int[][] parent, int[][] child, int iB, int jB) {
        final int size = child.length;

        for (int i = 0; i < size; i++) {
            System.arraycopy(parent[iB + i], jB, child[i], 0, size);
        }
    }

    /**
     * Copies the square matrix {@code child} into {@code parent} with its
     * top-left corner placed at {@code (iB, jB)}.
     *
     * @param child  source block (square)
     * @param parent destination matrix
     * @param iB     first destination row
     * @param jB     first destination column
     */
    public static void copySubArray(int[][] child, int[][] parent, int iB, int jB) {
        final int size = child.length;

        for (int i = 0; i < size; i++) {
            System.arraycopy(child[i], 0, parent[iB + i], jB, size);
        }
    }
}
/**
 * Strassen matrix multiplication that reuses a set of static scratch
 * matrices instead of allocating at every recursion level (the "cached"
 * variant).
 *
 * <p>Scratch layout: every recursion level works on rows starting at 0 and on
 * its own band of columns. A level of size {@code n} uses {@code n/2} columns
 * starting at {@code offs}; its children use the band that starts right after
 * it. A top-level size {@code n} therefore needs at most {@code n/2} rows and
 * {@code n - 1} columns.
 *
 * <p>The scratch matrices are shared static state, so this class must not be
 * used from several threads at once.
 */
class strassen2{

    /** Returns the display name of this implementation. */
    public String getName () {
        return "Strassen(cached)";
    }

    // Scratch storage for the seven Strassen products ...
    static int [][] p1;
    static int [][] p2;
    static int [][] p3;
    static int [][] p4;
    static int [][] p5;
    static int [][] p6;
    static int [][] p7;
    // ... and for the left/right operand of the product being computed.
    static int [][] t0;
    static int [][] t1;

    /**
     * Multiplies {@code a} by {@code b} and stores the product in {@code c}.
     *
     * @param c destination matrix (square; its length defines the size)
     * @param a left operand, same size as {@code c}
     * @param b right operand, same size as {@code c}
     * @return {@code c}
     */
    public static int[][] mult (int[][] c, int[][] a, int[][] b) {
        final int n = c.length;

        if (n == 0) {
            // Nothing to compute; also avoids allocating new int[0][-1].
            return c;
        }

        final int rows = n / 2;
        final int cols = n - 1;

        // Reallocate only when the cached scratch space is really too small.
        // (rows > 0 guarantees p1[0] exists once p1.length >= rows.)
        if (p1 == null || p1.length < rows || (rows > 0 && p1[0].length < cols)) {
            p1 = new int[rows][cols];
            p2 = new int[rows][cols];
            p3 = new int[rows][cols];
            p4 = new int[rows][cols];
            p5 = new int[rows][cols];
            p6 = new int[rows][cols];
            p7 = new int[rows][cols];
            t0 = new int[rows][cols];
            t1 = new int[rows][cols];
        }

        mult(c, a, b, 0, 0, n, 0);

        return c;
    }

    /**
     * Multiplies the {@code n x n} blocks of {@code a} and {@code b} whose
     * top-left corner is {@code (i0, j0)} and writes the product to the same
     * block position of {@code c}.
     *
     * <p>A 1x1 block is multiplied directly; an odd size greater than one
     * cannot be halved and is multiplied with the schoolbook method.
     *
     * @param c    destination matrix
     * @param a    left operand
     * @param b    right operand
     * @param i0   first row of the block in all three matrices
     * @param j0   first column of the block in all three matrices
     * @param n    block size
     * @param offs first scratch column this recursion level may use
     */
    public static void mult (int[][] c, int[][] a, int[][] b, int i0, int j0, int n, int offs) {
        if (n <= 0) {
            // Recursing on n/2 == 0 would never terminate.
            return;
        }

        if (n == 1) {
            c[i0][j0] = a[i0][j0] * b[i0][j0];
        } else if ((n & 1) != 0) {
            multClassic(c, a, b, i0, j0, n);
        } else {
            final int nBy2 = n/2;

            final int i1 = i0 + nBy2;
            final int j1 = j0 + nBy2;

            // offset applied to 'p' j index so recursive calls don't overwrite data
            final int jp0 = offs;

            // P1 <- (A11 + A22)(B11 + B22)
            //  T0 <- (A11 + A22), T1 <- (B11 + B22), P1 <- T0*T1
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i0][j + j0] + a[i + i1][j + j1];
                    t1[i + i0][j + jp0] = b[i + i0][j + j0] + b[i + i1][j + j1];
                }
            }

            mult(p1, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P2 <- (A21 + A22)B11
            //  T0 <- (A21 + A22), T1 <- B11, P2 <- T0*T1
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i1][j + j0] + a[i + i1][j + j1];
                    t1[i + i0][j + jp0] = b[i + i0][j + j0];
                }
            }

            mult(p2, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P3 <- A11(B12 - B22)
            //  T0 <- A11, T1 <- (B12 - B22), P3 <- T0*T1
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i0][j + j0];
                    t1[i + i0][j + jp0] = b[i + i0][j + j1] - b[i + i1][j + j1];
                }
            }

            mult(p3, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P4 <- A22(B21 - B11)
            //  T0 <- A22, T1 <- (B21 - B11), P4 <- T0*T1
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i1][j + j1];
                    t1[i + i0][j + jp0] = b[i + i1][j + j0] - b[i + i0][j + j0];
                }
            }

            mult(p4, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P5 <- (A11 + A12) B22
            //  T0 <- (A11 + A12), T1 <- B22, P5 <- T0*T1
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i0][j + j0] + a[i + i0][j + j1];
                    t1[i + i0][j + jp0] = b[i + i1][j + j1];
                }
            }

            mult(p5, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P6 <- (A21 - A11)(B11 + B12)
            //  T0 <- (A21 - A11), T1 <- (B11 + B12), P6 <- T0 * T1
            // The second factor is a SUM; using a difference here makes C22 wrong.
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i1][j + j0] - a[i + i0][j + j0];
                    t1[i + i0][j + jp0] = b[i + i0][j + j0] + b[i + i0][j + j1];
                }
            }

            mult(p6, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // P7 <- (A12 - A22)(B21 + B22)
            //  T0 <- (A12 - A22), T1 <- (B21 + B22), P7 <- T0 * T1
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    t0[i + i0][j + jp0] = a[i + i0][j + j1] - a[i + i1][j + j1];
                    t1[i + i0][j + jp0] = b[i + i1][j + j0] + b[i + i1][j + j1];
                }
            }

            mult(p7, t0, t1, i0, jp0, nBy2, offs + nBy2);

            // combine
            for (int i = 0; i < nBy2; ++i) {
                for (int j = 0; j < nBy2; ++j) {
                    // C11 = P1 + P4 - P5 + P7;
                    c[i + i0][j + j0] = p1[i + i0][j + jp0] + p4[i + i0][j + jp0] - p5[i + i0][j + jp0] + p7[i + i0][j + jp0];
                    // C12 = P3 + P5;
                    c[i + i0][j + j1] = p3[i + i0][j + jp0] + p5[i + i0][j + jp0];
                    // C21 = P2 + P4;
                    c[i + i1][j + j0] = p2[i + i0][j + jp0] + p4[i + i0][j + jp0];
                    // C22 = P1 + P3 - P2 + P6;
                    c[i + i1][j + j1] = p1[i + i0][j + jp0] + p3[i + i0][j + jp0] - p2[i + i0][j + jp0] + p6[i + i0][j + jp0];
                }
            }
        }
    }

    /** Schoolbook product of the n x n blocks at (i0, j0), used for odd sizes. */
    private static void multClassic (int[][] c, int[][] a, int[][] b, int i0, int j0, int n) {
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                int sum = 0;
                for (int k = 0; k < n; ++k) {
                    sum += a[i0 + i][j0 + k] * b[i0 + k][j0 + j];
                }
                c[i0 + i][j0 + j] = sum;
            }
        }
    }

    /** Prints every scratch matrix to standard output, for debugging. */
    void dumpInternal () {
        dump("P1", p1);
        dump("P2", p2);
        dump("P3", p3);
        dump("P4", p4);
        dump("P5", p5);
        dump("P6", p6);
        dump("P7", p7);
        dump("T0", t0);
        dump("T1", t1);
    }

    /**
     * Prints a label followed by the matrix, one bracketed, tab-separated row
     * per line. Kept local so this class does not depend on the benchmark
     * driver.
     */
    private static void dump (String label, int[][] m) {
        System.out.println(label);
        for (int[] row : m) {
            System.out.print("[\t");
            for (int val : row) {
                System.out.print(val);
                System.out.print('\t');
            }
            System.out.println(']');
        }
    }
}


/**
 * Schoolbook (triple-loop) multiplication of square int matrices.
 */
class Classical{
    /** Returns the display name of this implementation. */
    public String getName () {
        return "classic";
    }

    /**
     * Computes {@code a * b} and stores it in {@code c}.
     *
     * @param c destination matrix
     * @param a left operand; its length defines the size
     * @param b right operand
     * @return {@code c}
     */
    public static int[][] mult (int[][] c, int[][] a, int[][] b) {
        final int size = a.length;

        for (int row = 0; row < size; row++) {
            final int[] left = a[row];
            final int[] target = c[row];

            for (int col = 0; col < size; col++) {
                target[col] = dot(left, b, col, size);
            }
        }

        return c;
    }

    /** Dot product of one row of the left operand with one column of the right operand. */
    private static int dot (int[] left, int[][] right, int col, int size) {
        int acc = 0;

        for (int k = 0; k < size; k++) {
            acc += left[k] * right[k][col];
        }

        return acc;
    }
}

3 个答案:

答案 0 :(得分:5)

我看到的问题:

1)你的Strassen乘法是一直动态分配内存。这会破坏性能。

2)你的Strassen乘法应该在矩阵尺寸较小时切换到常规乘法,而不是一直递归到底(尽管这种优化在某种程度上会使你的测试失去意义)。

3)你的矩阵大小可能太小而无法看出差异。

您应该用几种不同的矩阵尺寸进行比较。也许256,512,1024,2048,4096,8192 ......然后绘制运行时间并观察趋势。如果这些尺寸全都是2的幂,那么您可能希望对矩阵大小使用对数刻度。

Strassen对大N来说速度更快。多大程度上取决于实施。你为经典做的只是一个基本的实现,在现代机器上也不是最佳的。

答案 1 :(得分:2)

除了实现问题,我认为你误解了算法的性能。就像phkahler所说的那样,你对算法性能的期望有点偏差。分治算法适用于大型输入,因为它们递归地将问题分解为可以更快地解决的子问题。

但是,与此拆分操作相关的开销可能导致算法在小型甚至中型输入上运行得更慢(有时慢得多)。通常,像Strassen这样的算法的理论分析会包括所谓的“断点”计算,即分治方法(连同其拆分开销)开始优于朴素方法的输入规模。

您的代码需要检查输入大小,并在输入小于断点时切换到朴素方法。

答案 2 :(得分:1)

写下Strassen算法对2 x 2矩阵所做的运算,并统计操作次数。这个数字绝对荒谬。用Strassen的方法处理2 x 2矩阵是愚蠢的。对于3 x 3或4 x 4矩阵也是如此,很可能对大得多的矩阵仍然如此。