是否有一个java库实现了Strassen的矩阵求逆算法?

时间:2011-04-21 01:52:01

标签: java algebra linear

我正在研究一个比较各种矩阵反演算法(Gauss-Jordan,Strassen和Coppersmith-Winograd)的项目。我找到了gauss-jordan实现,但我还没有找到任何Strassens矩阵求逆算法的实现。大量的谷歌时间没有发现任何东西。我查看了Jama库,Colt库和Apache Math库,但这些都使用了Gauss Elimination或LU Decomp。有谁知道实现Strassen矩阵求逆算法的java库?

对于那些不熟悉算法的人,请查看此page的底部。此外,Coppersmith-Winograd维基百科页面还简要介绍了Strassen的算法。

由于

1 个答案:

答案 0 :(得分:3)

public class Strassen 
{

    public static double[][] multiply(double[][] A, double[][] B)
    {
        try
        {
            checkInputStrassen(A,B);
        }
        catch (RuntimeException e)
        {
            throw e;
        }
        return strassenRecursive(A,B);
    }

    private static double[][] reconstructAnswer(double[][] r, double[][] s,
            double[][] t, double[][] u) 
    {
        int n = r.length*2;
        double[][] C = new double[n][n];
        copyBack(C,r,0,0);
        copyBack(C,s,0,n/2);
        copyBack(C,t,n/2,0);
        copyBack(C,u,n/2,n/2);
        return C;
    }

    private static void copyBack(double[][] C, double[][] r, int x, int y) 
    {
        for (int i=0; i<r.length; i++)
        {
            for (int j=0; j<r.length; j++)
            {
                C[x+i][y+j] = r[i][j];
            }
        }
    }

    private static void copy(double[][] a, double[][] A, int x, int y) 
    {
        for (int i=0; i<a.length; i++)
        {
            for (int j=0; j<a.length; j++)
            {
                a[i][j] = A[x+i][y+j]; 
            }
        }
    }

    private static double[][] strassenRecursive(double[][] A, double[][] B) 
    {
        int n = A.length;
        if (n==1)
        {
            double[][] C = new double[1][1];
            C[0][0]=A[0][0]*B[0][0];
            return C;
        }

        double[][] r,s,t,u, a,b,c,d,e,f,g,h, P1,P2,P3,P4,P5,P6,P7;
        r = new double[n/2][n/2];       s = new double[n/2][n/2];       t = new double[n/2][n/2];       
        u = new double[n/2][n/2];       a = new double[n/2][n/2];       b = new double[n/2][n/2];       
        c = new double[n/2][n/2];       d = new double[n/2][n/2];       e = new double[n/2][n/2];       
        f = new double[n/2][n/2];       g = new double[n/2][n/2];       h = new double[n/2][n/2];
        P1 = new double[n/2][n/2];      P2 = new double[n/2][n/2];      P3 = new double[n/2][n/2];      
        P4 = new double[n/2][n/2];      P5 = new double[n/2][n/2];      P6 = new double[n/2][n/2];      
        P7 = new double[n/2][n/2];      
        copy(a,A,0,0);
        copy(b,A,0,n/2);
        copy(c,A,n/2,0);
        copy(d,A,n/2,n/2);
        copy(e,B,0,0);
        copy(f,B,0,n/2);
        copy(g,B,n/2,0);
        copy(h,B,n/2,n/2);

        P1= strassenRecursive(a, add(f,h,-1)); // P1 = a(f-h) = af-ah
        P2= strassenRecursive(add(a,b,1), h); // P2 = (a+b)h = ah+bh
        P3= strassenRecursive(add(c,d,1), e); // P3 = (c+d)e = ce+de
        P4= strassenRecursive(d, add(g,e,-1)); // P4 = d(g-e) = dg-de
        P5= strassenRecursive(add(a,d,1), add(e,h,1)); // P5 = (a+d)(e+h)=ae+de+ah+dh
        P6= strassenRecursive(add(b,d,-1), add(g,h,1)); // P6 = (b-d)(g+h)=bg-dg+bh-dh
        P7= strassenRecursive(add(a,c,-1), add(e,f,1)); // P7 = (a-c)(e+f)=ae-ce+af-cf

        r = add(add(P5,P4,1),add(P2,P6,-1),-1); // r = P5+P4-P2+P6 = ae+bg
        s = add(P1,P2,1); // s = P1+P2 = af+bh
        t = add(P3,P4,1); // t = P3+P4 = ce+dg
        u = add(add(P5,P1,1),add(P3,P7,1),-1); //u = P5+P1-P3-P7 = cf+dh
        return reconstructAnswer(r,s,t,u);
    }

    private static double[][] add(double[][] A, double[][] B, int signofB)
    {
        int n = A.length;
        double[][] C = new double[n][n];
        for (int i=0; i<n; i++)
        {
            for (int j=0; j<n; j++)
            {
                C[i][j] = A[i][j] + signofB*B[i][j];
            }
        }
        return C;

    }

    private static void checkInputStrassen(double[][] A, double[][] B) 
    {
        int p = A.length;
        if (p==0)
        {
            throw new IllegalArgumentException("Null matrix");
        }
        int n=p;
        while (n>1)
        {
            if (n%2 != 0)
            {
                throw new IllegalArgumentException("Non power of 2 matrix");
            }
            n/=2;
        }

        int q = A[0].length;
        if (q==0)
        {
            throw new IllegalArgumentException("Null matrix");
        }

        if (q!=p)
        {
            throw new IllegalArgumentException("Nonsquare Matrix");
        }

        for (int i=1; i<p; i++)
        {
            if (A[i].length != q)
            {
                throw new IllegalArgumentException("Inconsistent matrix");
            }
        }
        if (B.length != q)
        {
            throw new IllegalArgumentException("Inconsistent dimensions");
        }

        int r = B[0].length;
        if (r!=p)
        {
            throw new IllegalArgumentException("Nonsquare Matrix");
        }
        for (int i=1; i<q; i++)
        {
            if (B[i].length != r)
            {
                throw new IllegalArgumentException("Inconsistent matrix");
            }
        }
    }
}

来自:http://www.cs.huji.ac.il/~omrif01/Strassen