我从某处复制了strassen的算法,然后执行了它。这是输出
n = 256
classical took 360ms
strassen 1 took 33609ms
strassen2 took 1172ms
classical took 437ms
strassen 1 took 32891ms
strassen2 took 1156ms
classical took 266ms
strassen 1 took 27234ms
strassen2 took 734ms
其中strassen1
是动态方法,缓存strassen2
和classical
是旧矩阵乘法。这意味着我们古老而简单的经典之作是最好的。这是真的还是我错了?这是Java中的代码。
import java.util.Random;
class TestIntMatrixMultiplication {
public static void main (String...args) throws Exception {
final int n = args.length > 0 ? Integer.parseInt(args[0]) : 256;
final int seed = args.length > 1 ? Integer.parseInt(args[1]) : 256;
final Random random = new Random(seed);
int[][] a, b, c;
a = new int[n][n];
b = new int[n][n];
c = new int[n][n];
for(int i=0; i<n; i++) {
for(int j=0; j<n; j++) {
a[i][j] = random.nextInt(100);
b[i][j] = random.nextInt(100);
}
}
System.out.println("n = " + n);
if (a.length < 64) {
System.out.println("A");
dumpMatrix(a);
System.out.println("B");
dumpMatrix(b);
System.out.println("classic");
Classical.mult(c, a, b);
dumpMatrix(c);
System.out.println("strassen");
strassen2.mult(c, a, b);
dumpMatrix(c);
return;
}
for (int i = 0; i <3; ++i) {
timeMultiplies1(a, b, c);
if (n <= 256)
timeMultiplies2( a, b, c);
timeMultiplies3( a, b, c);
}
}
static void timeMultiplies1 (int[][] a, int[][] b, int[][] c) {
final long start = System.currentTimeMillis();
Classical.mult(c, a, b);
final long finish = System.currentTimeMillis();
System.out.println("classical took " + (finish - start) + "ms");
}
static void timeMultiplies2(int[][] a, int[][] b, int[][] c) {
final long start = System.currentTimeMillis();
strassen1.mult(c, a, b);
final long finish = System.currentTimeMillis();
System.out.println("strassen 1 took " + (finish - start) + "ms");
}
static void timeMultiplies3 (int[][] a, int[][] b, int[][] c) {
final long start = System.currentTimeMillis();
strassen2.mult(c, a, b);
final long finish = System.currentTimeMillis();
System.out.println("strassen2 took " + (finish - start) + "ms");
}
static void dumpMatrix (int[][] m) {
for (int[] row : m) {
System.out.print("[\t");
for (int val : row) {
System.out.print(val);
System.out.print('\t');
}
System.out.println(']');
}
}
}
class strassen1{
public String getName () {
return "Strassen(dynamic)";
}
public static int[][] mult (int[][] c, int[][] a, int[][] b) {
return strassenMatrixMultiplication(a, b);
}
public static int [][] strassenMatrixMultiplication(int [][] A, int [][] B) {
int n = A.length;
int [][] result = new int[n][n];
if(n == 1) {
result[0][0] = A[0][0] * B[0][0];
} else {
int [][] A11 = new int[n/2][n/2];
int [][] A12 = new int[n/2][n/2];
int [][] A21 = new int[n/2][n/2];
int [][] A22 = new int[n/2][n/2];
int [][] B11 = new int[n/2][n/2];
int [][] B12 = new int[n/2][n/2];
int [][] B21 = new int[n/2][n/2];
int [][] B22 = new int[n/2][n/2];
divideArray(A, A11, 0 , 0);
divideArray(A, A12, 0 , n/2);
divideArray(A, A21, n/2, 0);
divideArray(A, A22, n/2, n/2);
divideArray(B, B11, 0 , 0);
divideArray(B, B12, 0 , n/2);
divideArray(B, B21, n/2, 0);
divideArray(B, B22, n/2, n/2);
int [][] P1 = strassenMatrixMultiplication(addMatrices(A11, A22), addMatrices(B11, B22));
int [][] P2 = strassenMatrixMultiplication(addMatrices(A21, A22), B11);
int [][] P3 = strassenMatrixMultiplication(A11, subtractMatrices(B12, B22));
int [][] P4 = strassenMatrixMultiplication(A22, subtractMatrices(B21, B11));
int [][] P5 = strassenMatrixMultiplication(addMatrices(A11, A12), B22);
int [][] P6 = strassenMatrixMultiplication(subtractMatrices(A21, A11), addMatrices(B11, B12));
int [][] P7 = strassenMatrixMultiplication(subtractMatrices(A12, A22), addMatrices(B21, B22));
int [][] C11 = addMatrices(subtractMatrices(addMatrices(P1, P4), P5), P7);
int [][] C12 = addMatrices(P3, P5);
int [][] C21 = addMatrices(P2, P4);
int [][] C22 = addMatrices(subtractMatrices(addMatrices(P1, P3), P2), P6);
copySubArray(C11, result, 0 , 0);
copySubArray(C12, result, 0 , n/2);
copySubArray(C21, result, n/2, 0);
copySubArray(C22, result, n/2, n/2);
}
return result;
}
public static int [][] addMatrices(int [][] A, int [][] B) {
int n = A.length;
int [][] result = new int[n][n];
for(int i=0; i<n; i++)
for(int j=0; j<n; j++)
result[i][j] = A[i][j] + B[i][j];
return result;
}
public static int [][] subtractMatrices(int [][] A, int [][] B) {
int n = A.length;
int [][] result = new int[n][n];
for(int i=0; i<n; i++)
for(int j=0; j<n; j++)
result[i][j] = A[i][j] - B[i][j];
return result;
}
public static void divideArray(int[][] parent, int[][] child, int iB, int jB) {
for(int i1 = 0, i2=iB; i1<child.length; i1++, i2++)
for(int j1 = 0, j2=jB; j1<child.length; j1++, j2++)
child[i1][j1] = parent[i2][j2];
}
public static void copySubArray(int[][] child, int[][] parent, int iB, int jB) {
for(int i1 = 0, i2=iB; i1<child.length; i1++, i2++)
for(int j1 = 0, j2=jB; j1<child.length; j1++, j2++)
parent[i2][j2] = child[i1][j1];
}
}
class strassen2{
public String getName () {
return "Strassen(cached)";
}
static int [][] p1;
static int [][] p2;
static int [][] p3;
static int [][] p4;
static int [][] p5;
static int [][] p6;
static int [][] p7;
static int [][] t0;
static int [][] t1;
public static int[][] mult (int[][] c, int[][] a, int[][] b) {
final int n = c.length;
if (p1 == null || p1.length < n) {
p1 = new int[n/2][n-1];
p2 = new int[n/2][n-1];
p3 = new int[n/2][n-1];
p4 = new int[n/2][n-1];
p5 = new int[n/2][n-1];
p6 = new int[n/2][n-1];
p7 = new int[n/2][n-1];
t0 = new int[n/2][n-1];
t1 = new int[n/2][n-1];
}
mult(c, a, b, 0, 0, n, 0);
return c;
}
public static void mult (int[][] c, int[][] a, int[][] b, int i0, int j0, int n, int offs) {
if(n == 1) {
c[i0][j0] = a[i0][j0] * b[i0][j0];
} else {
final int nBy2 = n/2;
final int i1 = i0 + nBy2;
final int j1 = j0 + nBy2;
// offset applied to 'p' j index so recursive calls don't overwrite data
final int jp0 = offs;
final int jp1 = nBy2 + offs;
// P1 <- (A11 + A22)(B11 + B22)
// T0 <- (A11 + A22), T1 <- (B11 + B22), P1 <- T0*T1
for (int i = 0; i < nBy2; ++i) {
for (int j = 0; j < nBy2; ++j) {
t0[i + i0][j + jp0] = a[i + i0][j + j0] + a[i + i1][j + j1];
t1[i + i0][j + jp0] = b[i + i0][j + j0] + b[i + i1][j + j1];
}
}
mult(p1, t0, t1, i0, jp0, nBy2, offs + nBy2);
// P2 <- (A21 + A22)B11
// T0 <- (A21 + A22), T1 <- B11, P2 <- T0*T1
for (int i = 0; i < nBy2; ++i) {
for (int j = 0; j < nBy2; ++j) {
t0[i + i0][j + jp0] = a[i + i1][j + j0] + a[i + i1][j + j1];
t1[i + i0][j + jp0] = b[i + i0][j + j0];
}
}
mult(p2, t0, t1, i0, jp0, nBy2, offs + nBy2);
// P3 <- A11(B12 - B22)
// T0 <- A11, T1 <- (B12 - B22), P3 <- T0*T1
for (int i = 0; i < nBy2; ++i) {
for (int j = 0; j < nBy2; ++j) {
t0[i + i0][j + jp0] = a[i + i0][j + j0];
t1[i + i0][j + jp0] = b[i + i0][j + j1] - b[i + i1][j + j1];
}
}
mult(p3, t0, t1, i0, jp0, nBy2, offs + nBy2);
// P4 <- A22(B21 - B11)
// T0 <- A22, T1 <- (B21 - B11), P4 <- T0*T1
for (int i = 0; i < nBy2; ++i) {
for (int j = 0; j < nBy2; ++j) {
t0[i + i0][j + jp0] = a[i + i1][j + j1];
t1[i + i0][j + jp0] = b[i + i1][j + j0] - b[i + i0][j + j0];
}
}
mult(p4, t0, t1, i0, jp0, nBy2, offs + nBy2);
// P5 <- (A11 + A12) B22
// T0 <- (A11 + A12), T1 <- B22, P5 <- T0*T1
for (int i = 0; i < nBy2; ++i) {
for (int j = 0; j < nBy2; ++j) {
t0[i + i0][j + jp0] = a[i + i0][j + j0] + a[i + i0][j + j1];
t1[i + i0][j + jp0] = b[i + i1][j + j1];
}
}
mult(p5, t0, t1, i0, jp0, nBy2, offs + nBy2);
// P6 <- (A21 - A11)(B11 - B12)
// T0 <- (A21 - A11), T1 <- (B11 - B12), P6 <- T0 * T1
for (int i = 0; i < nBy2; ++i) {
for (int j = 0; j < nBy2; ++j) {
t0[i + i0][j + jp0] = a[i + i1][j + j0] - a[i + i0][j + j0];
t1[i + i0][j + jp0] = b[i + i0][j + j0] - b[i + i0][j + j1];
}
}
mult(p6, t0, t1, i0, jp0, nBy2, offs + nBy2);
// P7 <- (A12 - A22)(B21 + B22)
// T0 <- (A12 - A22), T1 <- (B21 + B22), P7 <- T0 * T1
for (int i = 0; i < nBy2; ++i) {
for (int j = 0; j < nBy2; ++j) {
t0[i + i0][j + jp0] = a[i + i0][j + j1] - a[i + i1][j + j1];
t1[i + i0][j + jp0] = b[i + i1][j + j0] + b[i + i1][j + j1];
}
}
mult(p7, t0, t1, i0, jp0, nBy2, offs + nBy2);
// combine
for (int i = 0; i < nBy2; ++i) {
for (int j = 0; j < nBy2; ++j) {
// C11 = P1 + P4 - P5 + P7;
c[i + i0][j + j0] = p1[i + i0][j + jp0] + p4[i + i0][j + jp0] - p5[i + i0][j + jp0] + p7[i + i0][j + jp0];
// C12 = P3 + P5;
c[i + i0][j + j1] = p3[i + i0][j + jp0] + p5[i + i0][j + jp0];
// C21 = P2 + P4;
c[i + i1][j + j0] = p2[i + i0][j + jp0] + p4[i + i0][j + jp0];
// C22 = P1 + P3 - P2 + P6;
c[i + i1][j + j1] = p1[i + i0][j + jp0] + p3[i + i0][j + jp0] - p2[i + i0][j + jp0] + p6[i + i0][j + jp0];
}
}
}
}
void dumpInternal () {
System.out.println("P1");
TestIntMatrixMultiplication.dumpMatrix(p1);
System.out.println("P2");
TestIntMatrixMultiplication.dumpMatrix(p2);
System.out.println("P3");
TestIntMatrixMultiplication.dumpMatrix(p3);
System.out.println("P4");
TestIntMatrixMultiplication.dumpMatrix(p4);
System.out.println("P5");
TestIntMatrixMultiplication.dumpMatrix(p5);
System.out.println("P6");
TestIntMatrixMultiplication.dumpMatrix(p6);
System.out.println("P7");
TestIntMatrixMultiplication.dumpMatrix(p7);
System.out.println("T0");
TestIntMatrixMultiplication.dumpMatrix(t0);
System.out.println("T1");
TestIntMatrixMultiplication.dumpMatrix(t1);
}
}
class Classical{
public String getName () {
return "classic";
}
public static int[][] mult (int[][] c, int[][] a, int[][] b) {
int n = a.length;
for(int i=0; i<n; i++) {
final int[] a_i = a[i];
final int[] c_i = c[i];
for(int j=0; j<n; j++) {
int sum = 0;
for(int k=0; k<n; k++) {
sum += a_i[k] * b[k][j];
}
c_i[j] = sum;
}
}
return c;
}
}
答案 0 :(得分:5)
我看到的问题:
1)你的Strassen乘法是一直动态分配内存。这会破坏性能。
2)你的Strassen乘法应该切换到小尺寸的常规乘法,而不是一直向下递归(尽管这种优化排序会使你的测试无效)。
3)你的矩阵大小可能太小而无法看出差异。
您应该与几种不同的尺码进行比较。也许256,512,1024,2048,4096,8192 ......然后绘制时间并观察趋势。如果它的所有幂都是2,那么您可能希望在对数刻度上使用矩阵大小。
Strassen对大N来说速度更快。多大程度上取决于实施。你为经典做的只是一个基本的实现,在现代机器上也不是最佳的。
答案 1 :(得分:2)
除了实现问题,我认为你误解了算法的性能。就像phkahler所说的那样,你对算法性能的期望有点偏差。分治算法适用于大型输入,因为它们递归地将问题分解为可以更快地解决的子问题。
但是,与此拆分操作相关的开销可能导致算法对于小型甚至中型输入运行(有时更慢)。通常,像Strassen这样的算法的理论分析将包括所谓的“断点”计算。这是输入大小,其中分裂的开销变得优于天真的技术。
您的代码需要检查在断点处切换到天真技术的输入大小。
答案 2 :(得分:1)
记下Strassen算法对2 x 2矩阵的作用。统计操作。这个数字绝对荒谬。使用Strassen的方法处理2x2矩阵是愚蠢的。对于3 x 3或4 x 4矩阵也是如此,可能还有很长的路要走。