如何在矩阵中引用子矩阵

时间:2010-11-29 04:43:27

标签: java

我有一个nxn矩阵A,其中n是2的幂。矩阵A被分成4个大小相等的子矩阵。如何在java中引用子矩阵A11,A12,A21和A22的矩阵?我正在尝试划分和征服矩阵乘法算法(Strassen)

            A11 | A12
   A -->    ---------
            A21 | A22

编辑:矩阵存储为整数数组:int [] []。

3 个答案:

答案 0 :(得分:3)

那么,如果ij是你的指数,则获得A11为i = 0 ..(n / 2)-1,j = 0 ..(n / 2) - 1。 然后,A12用于i = 0 ..(n / 2)-1和j = n / 2..n-1,依此类推。

要“引用”它们,你只需要一个“i_min,i_max,j_min,j_max”而不是运行从0到n-1的索引,从最小值到最大值运行它们。

答案 1 :(得分:1)

这是Strassen algorithm for matrix multiplication

的实现
import java.io.*;

public class MatrixMultiplication {

    public static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));

    public MatrixMultiplication() throws IOException {
        int n;
        int[][] a, b;

        System.out.print("Enter the number for rows/colums: ");
        n = Integer.parseInt(br.readLine());

        a = new int[n][n];
        b = new int[n][n];  

        System.out.print("\n\n\nEnter the values for the first matrix:\n\n");
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                System.out.print("Enter the value for cell("+(i+1)+","+(j+1)+"): ");
                a[i][j] = Integer.parseInt(br.readLine());
            }
        }
        System.out.print("\n\n\nEnter the values for the second matrix:\n");
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                System.out.print("Enter the value for cell ("+(i+1)+","+(j+1)+"): ");
                b[i][j] = Integer.parseInt(br.readLine());
            }
        }

        System.out.print("\n\nMatrix multiplication using standard method:\n");
        print(multiplyWithStandard(a, b));  

        System.out.print("\n\nMatrix multiplication using Strassen method:\n");
        print(multiplyWithStandard(a, b));  
    }

    public int[][] multiplyWithStandard(int[][] a, int[][] b) {
        int n = a.length;
        int[][] c = new int[n][n];

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                for (int k = 0; k < n; k++) {
                    c[i][j] += a[i][k] * b[k][j];
                } 
            }
        }
        return c;
    }

    public int[][] multiplyWithStrassen(int [][] A, int [][] B) {
        int n = A.length;
        int [][] result = new int[n][n];

        if (n == 1) {
            result[0][0] = A[0][0] * B[0][0];
        } else if ((n%2 != 0 ) && (n != 1)) {
            int[][] a1, b1, c1;
            int n1 = n+1;
            a1 = new int[n1][n1];
            b1 = new int[n1][n1];
            c1 = new int[n1][n1];

            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    a1[i][j] = A[i][j];
                    b1[i][j] = B[i][j];
                }
            } 
            c1 = multiplyWithStrassen(a1, b1);   
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    result[i][j] = c1[i][j];
                }
            }   
        } else {
            int [][] A11 = new int[n/2][n/2];
            int [][] A12 = new int[n/2][n/2];
            int [][] A21 = new int[n/2][n/2];
            int [][] A22 = new int[n/2][n/2];

            int [][] B11 = new int[n/2][n/2];
            int [][] B12 = new int[n/2][n/2];
            int [][] B21 = new int[n/2][n/2];
            int [][] B22 = new int[n/2][n/2];

            divideArray(A, A11, 0 , 0);
            divideArray(A, A12, 0 , n/2);
            divideArray(A, A21, n/2, 0);
            divideArray(A, A22, n/2, n/2);

            divideArray(B, B11, 0 , 0);
            divideArray(B, B12, 0 , n/2);
            divideArray(B, B21, n/2, 0);
            divideArray(B, B22, n/2, n/2);

            int [][] M1 = multiplyWithStrassen(add(A11, A22), add(B11, B22));
            int [][] M2 = multiplyWithStrassen(add(A21, A22), B11);
            int [][] M3 = multiplyWithStrassen(A11, subtract(B12, B22));
            int [][] M4 = multiplyWithStrassen(A22, subtract(B21, B11));
            int [][] M5 = multiplyWithStrassen(add(A11, A12), B22);
            int [][] M6 = multiplyWithStrassen(subtract(A21, A11), add(B11, B12));
            int [][] M7 = multiplyWithStrassen(subtract(A12, A22), add(B21, B22));

            int [][] C11 = add(subtract(add(M1, M4), M5), M7);
            int [][] C12 = add(M3, M5);
            int [][] C21 = add(M2, M4);
            int [][] C22 = add(subtract(add(M1, M3), M2), M6);

            copyArray(C11, result, 0 , 0);
            copyArray(C12, result, 0 , n/2);
            copyArray(C21, result, n/2, 0);
            copyArray(C22, result, n/2, n/2);
        }
        return result;
    }

    public int[][] add(int [][] A, int [][] B) {
        int n = A.length;  
        int [][] result = new int[n][n]; 

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++)
                result[i][j] = A[i][j] + B[i][j];
            }   
        return result;
    }

    public int[][] subtract(int [][] A, int [][] B) {
        int n = A.length;
        int [][] result = new int[n][n];

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                result[i][j] = A[i][j] - B[i][j];
            }  
        }    
        return result;
    }

    private void divideArray(int[][] parent, int[][] child, int iB, int jB) {
        for (int i1 = 0, i2 = iB; i1 < child.length; i1++, i2++) {
            for (int j1 = 0, j2 = jB; j1 < child.length; j1++, j2++) {
                child[i1][j1] = parent[i2][j2];
            }
        }
    }

    private void copyArray(int[][] child, int[][] parent, int iB, int jB) {
        for(int i1 = 0, i2 = iB; i1 < child.length; i1++, i2++) {
            for(int j1 = 0, j2 = jB; j1 < child.length; j1++, j2++) {
                parent[i2][j2] = child[i1][j1];
            }
        }
    }

    public void print(int [][] array) {
        int n = array.length;  

        System.out.println();  
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                System.out.print(array[i][j] + "\t");
            }
            System.out.println();
        }
        System.out.println();
    }

    public static void main(String[] args) throws IOException {
        new MatrixMultiplication();
    }
} 

答案 2 :(得分:0)

我认为您必须决定是每次复制每个子矩阵的内容还是对寻址进行算术运算。您的问题意味着您的子矩阵是连续的而不是分裂的(如计算与未成年人和辅助因子的决定因素 - http://mathworld.wolfram.com/Determinant.html)。既然你没有说明你为什么要这样做,你已经遇到了什么性能,以及是否有递归较小矩阵我认为只有你可以决定复制的简单性或复杂性之间的平衡递归寻址。但我希望已有图书馆,我会检查http://commons.apache.org/math/