Is there a faster way to calculate the sums of submatrices?

Posted: 2016-01-05 10:05:52

Tags: java algorithm performance submatrix

I have to calculate the sums of several submatrices of a square matrix. My input has this format:

8 8 // the matrix is 8 x 8 and there are going to be 8 submatrices
-5 -4 -6 -2 1 -8 6 -1 //first row of the matrix
-9 7 -3 -7 2 0 -6 -2  // second row of the matrix etc.
6 -8 2 6 -7 0 3 -5
-1 3 9 4 -7 0 -5 -3
-8 0 0 -6 -5 -7 -7 0
2 7 6 2 -6 6 5 0
-1 -7 8 -7 6 7 -2 1
-8 -3 -5 2 -5 4 -1 -2
0 2 3 6  //upperRow, leftColumn, lowerRow, rightColumn of submatrix
2 6 4 6
0 7 1 7
7 4 7 4
1 7 7 7
2 7 6 7
4 5 6 5
6 2 7 5

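I need to calculate the sum of each of these submatrices (among other things). My code works correctly (it compiles, runs, and gives correct results), but my method public static int total(int[][] M, int upperRow, int leftColumn, int lowerRow, int rightColumn) kills the performance (I measured it). To be precise, the nested loop inside it is what kills it.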

Is there a faster (more efficient) way to calculate the sum of a submatrix?

import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map.Entry;

public class Main {

    public static void display(int[][] a2d) {
        for (int[] a : a2d) {
            for (int val : a) {
                System.out.print(val + " ");
            }
            System.out.println();
        }
    }

    public static int total(int[][] M, int upperRow, int leftColumn, int lowerRow, int rightColumn) {
        int rows = lowerRow - upperRow + 1;
        int cols = rightColumn - leftColumn + 1;
        int sum = 0;
        int columnToCopyFrom = leftColumn;
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                sum += M[upperRow][columnToCopyFrom];
                columnToCopyFrom++;
            }
            columnToCopyFrom = leftColumn;
            upperRow++;
        }
        return sum;
    }



    public static void main(String[] args) throws Exception {
        //BufferedReader br = new BufferedReader(new FileReader("input3"));
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        String[] firstLine = br.readLine().split(" ");
        int n = Integer.parseInt(firstLine[0]);
        int k = Integer.parseInt(firstLine[1]);
        int[][] M = new int[n][n];
        for (int i = 0; i < n; i++) {
            String[] rowContents = br.readLine().split(" ");
            for (int j = 0; j < rowContents.length; j++) {
                M[i][j] = Integer.parseInt(rowContents[j]);
            }
        }

        int avgSum = 0;
        int total;
        HashMap<Integer, Integer> map = new HashMap<Integer, Integer>();
        for (int i = 0; i < k; i++) {
            String[] rowContents = br.readLine().split(" ");
            int upperRow = Integer.parseInt(rowContents[0]);
            int leftColumn = Integer.parseInt(rowContents[1]);
            int lowerRow = Integer.parseInt(rowContents[2]);
            int rightColumn = Integer.parseInt(rowContents[3]);
            total = total(M, upperRow, leftColumn, lowerRow, rightColumn);
            // average
            avgSum += total;
            // abstraction classes: count how many times each sum occurs
            if (!map.containsKey(total)) {
                map.put(total, 1);
            } else {
                map.put(total, map.get(total) + 1);
            }
            //display(cutOut);
        }
        int maxCount = 0;
        int maxAbstractionClass = Integer.MIN_VALUE;
        ArrayList<Entry<Integer, Integer>> list = new ArrayList<>();
        for (Entry<Integer, Integer> entry : map.entrySet()) {
            if (entry.getValue() > maxCount || (entry.getValue() == maxCount && entry.getKey() > maxAbstractionClass)) {
                maxAbstractionClass = entry.getKey();
                maxCount = entry.getValue();
            }

        }
        for (Entry<Integer, Integer> entry : map.entrySet()) {
            if (maxCount == entry.getValue()) {
                list.add(entry);
            }
        }
        System.out.print(map.size() + " " + list.size() + " " + avgSum / k);

    }
}
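The bookkeeping at the end of main (counting how often each sum occurs, finding the highest count, and averaging) can be expressed more compactly with Map.merge. This is a minimal sketch, not the asker's code: the class name SumStats and the method summarize are hypothetical, it assumes the k sums have already been collected into an int array, and it accumulates the total in a long so that many large sums cannot overflow. It leaves out maxAbstractionClass, which the original computes but never prints.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical helper: reproduces the three numbers printed by main
// (number of distinct sums, how many sums share the highest frequency, integer average).
final class SumStats {
    static String summarize(int[] sums) {
        Map<Integer, Integer> freq = new HashMap<>();
        long total = 0; // long avoids overflow when many large sums are added
        for (int s : sums) {
            freq.merge(s, 1, Integer::sum); // insert 1, or increment the existing count
            total += s;
        }
        int maxCount = 0;
        for (int c : freq.values()) {
            maxCount = Math.max(maxCount, c);
        }
        int tied = 0;
        for (int c : freq.values()) {
            if (c == maxCount) {
                tied++;
            }
        }
        // Integer division truncates toward zero, like avgSum / k in the original.
        return freq.size() + " " + tied + " " + (total / sums.length);
    }

    public static void main(String[] args) {
        int[] sampleSums = {-18, -9, -3, -5, -11, -7, 6, 10};
        System.out.println(summarize(sampleSums));
    }
}
```

For the sample input, the eight query sums worked out by hand are -18, -9, -3, -5, -11, -7, 6 and 10. They are all distinct, so this prints 8 8 -4 (the average -37 / 8 truncates to -4).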

I shaved off a few milliseconds by remembering the totals in a HashMap:
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map.Entry;

public class Main {
    public static HashMap<FourNumbers, Integer> mapOfTotals = new HashMap<FourNumbers, Integer>();
    public static void display(int[][] a2d) {
        for (int[] a : a2d) {
            for (int val : a) {
                System.out.print(val + " ");
            }
            System.out.println();
        }
    }

    public static int total(int[][] M, int upperRow, int leftColumn, int lowerRow, int rightColumn) {
        FourNumbers fourNumbers = new FourNumbers(upperRow, leftColumn, lowerRow, rightColumn);
        if(mapOfTotals.containsKey(fourNumbers)){
            return mapOfTotals.get(fourNumbers);
        }
        int rows = lowerRow - upperRow + 1;
        int cols = rightColumn - leftColumn + 1;
        int sum = 0;
        int columnToCopyFrom = leftColumn;
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                sum += M[upperRow][columnToCopyFrom];
                columnToCopyFrom++;
            }
            columnToCopyFrom = leftColumn;
            upperRow++;
        }
        mapOfTotals.put(fourNumbers, sum);
        return sum;
    }

    public static void main(String[] args) throws Exception {
        //BufferedReader br = new BufferedReader(new FileReader("input3"));
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        String[] firstLine = br.readLine().split(" ");
        int n = Integer.parseInt(firstLine[0]);
        int k = Integer.parseInt(firstLine[1]);
        int[][] M = new int[n][n];
        for (int i = 0; i < n; i++) {
            String[] rowContents = br.readLine().split(" ");
            for (int j = 0; j < rowContents.length; j++) {
                M[i][j] = Integer.parseInt(rowContents[j]);
            }
        }

        int avgSum = 0;
        int total;
        HashMap<Integer, Integer> map = new HashMap<Integer, Integer>(); // key: the abstraction class (the sum); value: how many times that abstraction class occurs
        for (int i = 0; i < k; i++) {
            String[] rowContents = br.readLine().split(" ");
            int upperRow = Integer.parseInt(rowContents[0]);
            int leftColumn = Integer.parseInt(rowContents[1]);
            int lowerRow = Integer.parseInt(rowContents[2]);
            int rightColumn = Integer.parseInt(rowContents[3]);
            total = total(M, upperRow, leftColumn, lowerRow, rightColumn);
            // average
            avgSum += total;
            // abstraction classes: count how many times each sum occurs
            if (!map.containsKey(total)) {
                map.put(total, 1);
            } else {
                map.put(total, map.get(total) + 1);
            }
            //display(cutOut);
        }
        int maxCount = 0;
        int maxAbstractionClass = Integer.MIN_VALUE;
        ArrayList<Entry<Integer, Integer>> list = new ArrayList<>();
        for (Entry<Integer, Integer> entry : map.entrySet()) {
            if (entry.getValue() > maxCount || (entry.getValue() == maxCount && entry.getKey() > maxAbstractionClass)) {
                maxAbstractionClass = entry.getKey();
                maxCount = entry.getValue();
            }

        }
        for (Entry<Integer, Integer> entry : map.entrySet()) {
            if (maxCount == entry.getValue()) {
                list.add(entry);
            }
        }
        System.out.print(map.size() + " " + list.size() + " " + avgSum / k);

    }
}

class FourNumbers {
    int upperRow, leftColumn, lowerRow, rightColumn;

    public FourNumbers(int upperRow, int leftColumn, int lowerRow, int rightColumn) {
        this.upperRow = upperRow;
        this.leftColumn = leftColumn;
        this.lowerRow = lowerRow;
        this.rightColumn = rightColumn;
    }

    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = prime * result + leftColumn;
        result = prime * result + lowerRow;
        result = prime * result + rightColumn;
        result = prime * result + upperRow;
        return result;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj)
            return true;
        if (obj == null)
            return false;
        if (getClass() != obj.getClass())
            return false;
        FourNumbers other = (FourNumbers) obj;
        if (leftColumn != other.leftColumn)
            return false;
        if (lowerRow != other.lowerRow)
            return false;
        if (rightColumn != other.rightColumn)
            return false;
        if (upperRow != other.upperRow)
            return false;
        return true;
    }

}
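On Java 16 or later, a record can stand in for FourNumbers, because the compiler generates equals() and hashCode() from the four components. This is a minimal sketch assuming Java 16+; Query, MemoTotals and the computed counter are hypothetical names. The cache only saves work when exactly the same four numbers appear more than once; in the sample input every query is different, so each one is still summed cell by cell.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical replacement for FourNumbers: a record gets equals() and
// hashCode() generated from its four components.
record Query(int upperRow, int leftColumn, int lowerRow, int rightColumn) {}

final class MemoTotals {
    private final int[][] m;
    private final Map<Query, Integer> cache = new HashMap<>();
    int computed = 0; // number of cache misses, for illustration only

    MemoTotals(int[][] m) {
        this.m = m;
    }

    int total(int upperRow, int leftColumn, int lowerRow, int rightColumn) {
        Query q = new Query(upperRow, leftColumn, lowerRow, rightColumn);
        // Compute the sum only the first time this exact query is seen.
        return cache.computeIfAbsent(q, this::sumDirectly);
    }

    // Same inclusive-bounds double loop as the original total().
    private int sumDirectly(Query q) {
        computed++;
        int sum = 0;
        for (int i = q.upperRow(); i <= q.lowerRow(); i++) {
            for (int j = q.leftColumn(); j <= q.rightColumn(); j++) {
                sum += m[i][j];
            }
        }
        return sum;
    }
}
```

Memoization does not change the worst case: k distinct queries still cost as much as before, plus the hashing overhead.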

1 answer:

Answer 0 (score: 0)

If you have to compute the sums of all submatrices of size 4x4, you can do this:

    int sum = 0;
    for (int i = 0; i < 4; i++) {
        for (int j = 0; j < 4; j++) {
            sum += M[i][j];
        }
    }
    stored[0][0] = sum; // stored[r][c] = sum of the 4x4 block whose top-left corner is (r, c)

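Then, for each subsequent submatrix, add the column or row that enters at the end and subtract the column or row that leaves at the front, instead of summing the whole block again: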

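    // slide one column to the right: block with top-left corner (0, 1)
    sum = stored[0][0];
    for (int i = 0; i < 4; i++) {
        sum -= M[i][0];   // column leaving on the left
        sum += M[i][4];   // column entering on the right
    }
    stored[0][1] = sum;

    // slide one row down: block with top-left corner (1, 0)
    sum = stored[0][0];
    for (int j = 0; j < 4; j++) {
        sum -= M[0][j];   // row leaving at the top
        sum += M[4][j];   // row entering at the bottom
    }
    stored[1][0] = sum;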
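The same sliding idea extends to every k x k block: compute the top-left block once, slide right one column at a time within each band of rows, and start each new band by sliding the previous band's first block down one row. Each slide costs O(k) instead of O(k²). This is a minimal sketch assuming a square matrix and 1 <= k <= n; the names SlidingBlockSums and blockSums are hypothetical, and stored follows the answer's naming. It only helps when all queried submatrices have the same size, which is not the case for the queries in the question.

```java
// Hypothetical sketch: stored[r][c] = sum of the k x k block whose top-left corner is (r, c).
final class SlidingBlockSums {
    static int[][] blockSums(int[][] m, int k) {
        int n = m.length;
        int out = n - k + 1; // number of block positions per dimension
        int[][] stored = new int[out][out];

        // Sum of the very first block, computed directly.
        int first = 0;
        for (int i = 0; i < k; i++) {
            for (int j = 0; j < k; j++) {
                first += m[i][j];
            }
        }
        stored[0][0] = first;

        for (int r = 0; r < out; r++) {
            if (r > 0) {
                // Slide the first block of the previous band down by one row.
                int sum = stored[r - 1][0];
                for (int j = 0; j < k; j++) {
                    sum -= m[r - 1][j];     // row leaving at the top
                    sum += m[r + k - 1][j]; // row entering at the bottom
                }
                stored[r][0] = sum;
            }
            for (int c = 1; c < out; c++) {
                // Slide one column to the right within the band.
                int sum = stored[r][c - 1];
                for (int i = r; i < r + k; i++) {
                    sum -= m[i][c - 1];     // column leaving on the left
                    sum += m[i][c + k - 1]; // column entering on the right
                }
                stored[r][c] = sum;
            }
        }
        return stored;
    }
}
```

For the 3x3 matrix {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}} and k = 2, the four block sums are 12, 16, 24 and 28.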

Also, the extra variables in this loop are redundant; the loop variables alone are enough:

    for (int i = 0; i < rows; i++) {
        for (int j = 0; j < cols; j++) {
            sum += M[upperRow][columnToCopyFrom];
            columnToCopyFrom++;
        }
        columnToCopyFrom = leftColumn;
        upperRow++;
    }

The loop above can be simplified to:

    for (int i = upperRow; i <= lowerRow; i++) {
        for (int j = leftColumn; j <= rightColumn; j++) {
            sum += M[i][j];
        }
    }
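For queries of arbitrary size, like the ones in the question, a standard technique not mentioned in this thread is a 2D prefix-sum table (also called a summed-area table). After one O(n²) pass over the matrix, each submatrix sum takes four table lookups instead of a loop over its cells. This is a minimal sketch; PrefixSums, build and total are hypothetical names, and total keeps the argument order and inclusive bounds of the question's total().

```java
// Sketch of a 2D prefix-sum table: p[i][j] = sum of m[0..i-1][0..j-1].
final class PrefixSums {
    static int[][] build(int[][] m) {
        int rows = m.length;
        int cols = m[0].length;
        int[][] p = new int[rows + 1][cols + 1]; // row 0 and column 0 stay 0
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                // cell + area above + area to the left - overlap counted twice
                p[i + 1][j + 1] = m[i][j] + p[i][j + 1] + p[i + 1][j] - p[i][j];
            }
        }
        return p;
    }

    // Inclusive bounds, same argument order as the question's total().
    static int total(int[][] p, int upperRow, int leftColumn, int lowerRow, int rightColumn) {
        return p[lowerRow + 1][rightColumn + 1]
                - p[upperRow][rightColumn + 1]
                - p[lowerRow + 1][leftColumn]
                + p[upperRow][leftColumn];
    }
}
```

In main, the table would be built once right after the matrix is read, and each query would call PrefixSums.total(p, upperRow, leftColumn, lowerRow, rightColumn) instead of looping over cells. For the first sample query (0 2 3 6) this returns -18. Table entries may overflow int on large inputs, but because Java int arithmetic wraps around, the four-term result is still exact whenever the true submatrix sum fits in an int.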