用分治法将两个矩阵相乘

时间:2019-05-09 17:34:27

标签: java matrix

所以我需要将两个矩阵相乘并除以征服,但我需要使用以下算法:

enter image description here

  

我认为问题在于合并部分

我遇到的问题是,我不知道如何合并结果,实际上我认为我正确理解了这些术语,但是我不知道如何将它们与该算法结合起来。

我的想法是首先检查基本情况,即是否让MatrixA像只有一行的向量,而让MatrixB像只有一行的向量。 如果发生这种情况,我需要遍历两个向量并进行乘法和求和,然后将结果返回到数组中。

另一部分是一般情况,这意味着矩阵比我想要的要大,因此,如算法所示,我将通过拆分来缩小它们,并通过四个递归调用来实现。

问题是,结合方面。如果两个Matrix的大小相同,则效果很好,但是如果一个与另一个大小不同,则会崩溃

这是我的代码:

    package MatrixMultiplicationV1;

public class MatrixMultiplicationV1
{
public static void main(String[] args)
{
    //int[][] matrixA = {{-3,4,1,6,9},{2,-1,-4,0,-2},{1,1,4,-2,3}};
    //int[][] matrixB = {{1,3,-2,-5},{0,2,1,2},{-4,-5,-3,-2},{1,3,4,2},{0,0,3,0}};

    int[][] matrixA = {{2,3},{1,4}};
    int[][] matrixB = {{5,7},{8,9}};

    int[][] result = divideAndConquer(matrixA,matrixB,0,matrixA.length-1,0,
            matrixA[0].length-1,0,matrixB.length-1,0, matrixB[0].length-1);
    print(result);
}

private static int[][] divideAndConquer(int[][] matrixA, int[][] matrixB, int beginRowsA, int endRowsA,
                                        int beginColumnsA, int endColumnsA, int beginRowsB, int endRowsB,
                                        int beginColumnsB, int endColumnsB)
{
    //Base case
    if((beginRowsA == endRowsA) && (beginColumnsB == endColumnsB))
    {
        int[][] matrixC = new int[1][1]; //Rows Matrix A * Columns Matrix B
        int val = 0;
        for(int i=0; i<=endColumnsA; i++)
        {

            val += matrixA[beginRowsA][i]*matrixB[i][beginColumnsB];
        }
        matrixC[0][0] = val;
        return matrixC;
    }
    else //General Case
    {
        int middleRowsMatrixA = (beginRowsA+endRowsA)/2;
        int middleColumnsMatrixB = (beginColumnsB+endColumnsB)/2;

        int[][] matrixA1B1 =divideAndConquer(matrixA,matrixB,beginRowsA,middleRowsMatrixA,beginColumnsA,endColumnsA,beginRowsB,
                endRowsB,beginColumnsB,middleColumnsMatrixB);
        int[][] matrixA1B2 = divideAndConquer(matrixA,matrixB,beginRowsA,middleRowsMatrixA,beginColumnsA,endColumnsA,
                beginRowsB,endRowsB,middleColumnsMatrixB+1,endColumnsB);
        int[][] matrixA2B1 = divideAndConquer(matrixA,matrixB,middleRowsMatrixA+1,endRowsA,beginColumnsA,
                endColumnsA,beginRowsB,endRowsB,beginColumnsB,middleColumnsMatrixB);
        int[][] matrixA2B2 = divideAndConquer(matrixA,matrixB,middleRowsMatrixA+1,endRowsA,beginColumnsA,
                endColumnsA,beginRowsB,endRowsB,middleColumnsMatrixB+1,endColumnsB);
        int[][] matrixC = new int[matrixA.length][matrixB[0].length];
        return combine(matrixA1B1,matrixA1B2,matrixA2B1,matrixA2B2,matrixC);

    }
}

private static int[][] combine(int[][] matrixA1B1, int[][] matrixA1B2, int[][] matrixA2B1, int[][] matrixA2B2,
                            int[][] matrixC)
{
    matrixC[0][0] = matrixA1B1[0][0];
    matrixC[1][0] = matrixA2B1[0][0];
    matrixC[0][1] = matrixA1B2[0][0];
    matrixC[1][1] = matrixA2B2[0][0];

    return matrixC;
}


private static void print(int[][] matrix)
{
    for(int i=0; i<matrix.length; i++)
    {
        for(int j=0; j<matrix[0].length; j++)
        {
            System.out.print(matrix[i][j]+"  ");
        }
        System.out.println();
    }
    System.out.println();
}

}

0 个答案:

没有答案