ikj算法的优化(矩阵乘法)

时间:2017-05-08 00:24:11

标签: java algorithm matrix optimization

创建自定义矩阵类我使用ikj算法实现乘法,现在我正在尝试优化它。问题是应该更好的版本大约慢5倍,我无法理解为什么。

这是具有“基本”算法的Matrix类:

class Matrix {

    private double[][] m;           // matrix
    private int rows;
    private int cols;

   // other stuff...


    // does some checks and returns requested matrix value
    // I know this will slow down computation, but it's not the relevant part
    public double get(int row, int col) {
        if (row >= rows || col >= cols)
            throw new IndexOutOfBoundsException(); // to catch block
        else
            return m[startRow + row][startCol + col];
    }


    public Matrix multiply(Matrix other) {
        int n = rows;
        int m = cols;
        int p = other.cols;

        double[][] prod = new double[n][p];

        for (int i = 0; i < n; i++)
            for (int k = 0; k < m; k++)
                for (int j = 0; j < p; j++)
                    prod[i][j] += get(i,k) * other.get(k,j);

        return new Matrix(prod);
    }
}

这是修改后的算法:

public Matrix multiplyOpt(Matrix other) {
    int n = rows;
    int m = cols;
    int p = other.cols;

    double[][] prod = new double[n][p];

    for (int i = 0; i < n; i++) {
        for (int k = 0; k < m; k++) {
            double aik = get(i,k);
            for (int j = 0; j < p; j++) {
                prod[i][j] += aik * other.get(k,j);
            }
        }
    }

    return new Matrix(prod);
}

我的意思是,在循环外移动调用它将被称为n x m次而不是n x m x p。

这些是随机矩阵乘法的结果(永远不会抛出异常):

multiply time = 0.599s
multiplyOpt time = 3.041s

为什么这种变化使它变慢而不是更快?

编辑

通过以下方式获得计时:

double[][] m1 = new double[1000][750];
double[][] m2 = new double[750][1250];

for (int i = 0; i < m1.length; i++)
    for (int j = 0; j < m1[0].length; j++)
        m1[i][j] = new Double(Math.random());

for (int i = 0; i < m2.length; i++)
    for (int j = 0; j < m2[0].length; j++)
        m2[i][j] = new Double(Math.random());

Matrix a = new Matrix(m1);
Matrix b = new Matrix(m2);

long start = System.currentTimeMillis();
Matrix c = a.multiply(b);
long stop = System.currentTimeMillis();
double time = (stop - start) / 1000.0;
System.out.println("multiply time = "+time);

start = System.currentTimeMillis();
c = a.multiplyOpt(b);
stop = System.currentTimeMillis();
time = (stop - start) / 1000.0;
System.out.println("multiplyOpt time = "+time);

0 个答案:

没有答案