用Java连接矩阵乘法

时间:2011-03-29 12:47:43

标签: java join matrix fork multiplication

我正在对Java 7中的fork / join框架进行一些性能研究。为了改进测试结果,我想在测试期间使用不同的递归算法。其中一个是乘法矩阵。

我从Doug Lea的网站()下载了以下示例:

public class MatrixMultiply {

  static final int DEFAULT_GRANULARITY = 16;

  /** The quadrant size at which to stop recursing down
   * and instead directly multiply the matrices.
   * Must be a power of two. Minimum value is 2.
   **/
  static int granularity = DEFAULT_GRANULARITY;

  public static void main(String[] args) {

    final String usage = "Usage: java MatrixMultiply <threads> <matrix size (must be a power of two)> [<granularity>] \n Size and granularity must be powers of two.\n For example, try java MatrixMultiply 2 512 16";

    try {
      int procs;
      int n;
      try {
        procs = Integer.parseInt(args[0]);
        n = Integer.parseInt(args[1]);
        if (args.length > 2) granularity = Integer.parseInt(args[2]);
      }

      catch (Exception e) {
        System.out.println(usage);
        return;
      }

      if ( ((n & (n - 1)) != 0) || 
           ((granularity & (granularity - 1)) != 0) ||
           granularity < 2) {
        System.out.println(usage);
        return;
      }

      float[][] a = new float[n][n];
      float[][] b = new float[n][n];
      float[][] c = new float[n][n];
      init(a, b, n);

      FJTaskRunnerGroup g = new FJTaskRunnerGroup(procs);
      g.invoke(new Multiplier(a, 0, 0, b, 0, 0, c, 0, 0, n));
      g.stats();

      // check(c, n);
    }
    catch (InterruptedException ex) {}
  }


  // To simplify checking, fill with all 1's. Answer should be all n's.
  static void init(float[][] a, float[][] b, int n) {
    for (int i = 0; i < n; ++i) {
      for (int j = 0; j < n; ++j) {
        a[i][j] = 1.0F;
        b[i][j] = 1.0F;
      }
    }
  }

  static void check(float[][] c, int n) {
    for (int i = 0; i < n; i++ ) {
      for (int j = 0; j < n; j++ ) {
        if (c[i][j] != n) {
          throw new Error("Check Failed at [" + i +"]["+j+"]: " + c[i][j]);
        }
      }
    }
  }

  /** 
   * Multiply matrices AxB by dividing into quadrants, using algorithm:
   * <pre>
   *      A      x      B                             
   *
   *  A11 | A12     B11 | B12     A11*B11 | A11*B12     A12*B21 | A12*B22 
   * |----+----| x |----+----| = |--------+--------| + |---------+-------|
   *  A21 | A22     B21 | B21     A21*B11 | A21*B21     A22*B21 | A22*B22 
   * </pre>
   */


  static class Multiplier extends FJTask {
    final float[][] A;   // Matrix A
    final int aRow;      // first row    of current quadrant of A
    final int aCol;      // first column of current quadrant of A

    final float[][] B;   // Similarly for B
    final int bRow;
    final int bCol;

    final float[][] C;   // Similarly for result matrix C
    final int cRow;
    final int cCol;

    final int size;      // number of elements in current quadrant

    Multiplier(float[][] A, int aRow, int aCol,
               float[][] B, int bRow, int bCol,
               float[][] C, int cRow, int cCol,
               int size) {
      this.A = A; this.aRow = aRow; this.aCol = aCol;
      this.B = B; this.bRow = bRow; this.bCol = bCol;
      this.C = C; this.cRow = cRow; this.cCol = cCol;
      this.size = size;
    }

    public void run() {

      if (size <= granularity) {
        multiplyStride2();
      }

      else {
        int h = size / 2;

        coInvoke(new FJTask[] {
          seq(new Multiplier(A, aRow,   aCol,    // A11
                             B, bRow,   bCol,    // B11
                             C, cRow,   cCol,    // C11
                             h),
              new Multiplier(A, aRow,   aCol+h,  // A12
                             B, bRow+h, bCol,    // B21
                             C, cRow,   cCol,    // C11
                             h)),

          seq(new Multiplier(A, aRow,   aCol,    // A11
                             B, bRow,   bCol+h,  // B12
                             C, cRow,   cCol+h,  // C12
                             h),
              new Multiplier(A, aRow,   aCol+h,  // A12
                             B, bRow+h, bCol+h,  // B22
                             C, cRow,   cCol+h,  // C12
                             h)),

          seq(new Multiplier(A, aRow+h, aCol,    // A21
                             B, bRow,   bCol,    // B11
                             C, cRow+h, cCol,    // C21
                             h),
              new Multiplier(A, aRow+h, aCol+h,  // A22
                             B, bRow+h, bCol,    // B21
                             C, cRow+h, cCol,    // C21
                             h)),

          seq(new Multiplier(A, aRow+h, aCol,    // A21
                             B, bRow,   bCol+h,  // B12
                             C, cRow+h, cCol+h,  // C22
                             h),
              new Multiplier(A, aRow+h, aCol+h,  // A22
                             B, bRow+h, bCol+h,  // B22
                             C, cRow+h, cCol+h,  // C22
                             h))
        });
      }
    }

    /** 
     * Version of matrix multiplication that steps 2 rows and columns
     * at a time. Adapted from Cilk demos.
     * Note that the results are added into C, not just set into C.
     * This works well here because Java array elements
     * are created with all zero values.
     **/

    void multiplyStride2() {
      for (int j = 0; j < size; j+=2) {
        for (int i = 0; i < size; i +=2) {

          float[] a0 = A[aRow+i];
          float[] a1 = A[aRow+i+1];

          float s00 = 0.0F; 
          float s01 = 0.0F; 
          float s10 = 0.0F; 
          float s11 = 0.0F; 

          for (int k = 0; k < size; k+=2) {

            float[] b0 = B[bRow+k];

            s00 += a0[aCol+k]   * b0[bCol+j];
            s10 += a1[aCol+k]   * b0[bCol+j];
            s01 += a0[aCol+k]   * b0[bCol+j+1];
            s11 += a1[aCol+k]   * b0[bCol+j+1];

            float[] b1 = B[bRow+k+1];

            s00 += a0[aCol+k+1] * b1[bCol+j];
            s10 += a1[aCol+k+1] * b1[bCol+j];
            s01 += a0[aCol+k+1] * b1[bCol+j+1];
            s11 += a1[aCol+k+1] * b1[bCol+j+1];
          }

          C[cRow+i]  [cCol+j]   += s00;
          C[cRow+i]  [cCol+j+1] += s01;
          C[cRow+i+1][cCol+j]   += s10;
          C[cRow+i+1][cCol+j+1] += s11;
        }
      }
    }

  }

}

此代码是为旧版本的fork / join框架编写的。所以我必须重写它。我重写的代码实现了我自己的界面,如下所示:

public class Java7MatrixMultiply implements Algorithm { 
    private static final int SIZE = 32;
    private static final int THRESHOLD = 8;

    private float[][] a = new float[SIZE][SIZE];
    private float[][] b = new float[SIZE][SIZE];
    private float[][] c = new float[SIZE][SIZE];

    ForkJoinPool forkJoinPool;

    @Override
    public void initialize() {
        init(a, b, SIZE);
    }

    @Override
    public void execute() {
        MatrixMultiplyTask mainTask = new MatrixMultiplyTask(a, 0, 0, b, 0, 0, c, 0, 0, SIZE);
        forkJoinPool = new ForkJoinPool();
        forkJoinPool.invoke(mainTask);

        System.out.println("Terminated!");
    }

    @Override
    public void printResult() { 
        check(c, SIZE);

        for (int i = 0; i < SIZE; i++) {
            for (int j = 0; j < SIZE; j++) {
                System.out.print(c[i][j] + " ");
            }

            System.out.println();
        }
    }

    // To simplify checking, fill with all 1's. Answer should be all n's.
    static void init(float[][] a, float[][] b, int n) {
        for (int i = 0; i < n; ++i) {
            for (int j = 0; j < n; ++j) {
                a[i][j] = 1.0F;
                b[i][j] = 1.0F;
            }
        }
    }

    static void check(float[][] c, int n) {
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (c[i][j] != n) {
                    //throw new Error("Check Failed at [" + i + "][" + j + "]: " + c[i][j]);
                    System.out.println("Check Failed at [" + i + "][" + j + "]: " + c[i][j]);
                }
            }
        }
    }

    private class MatrixMultiplyTask extends RecursiveAction {
        private final float[][] A; // Matrix A
        private final int aRow; // first row of current quadrant of A
        private final int aCol; // first column of current quadrant of A

        private final float[][] B; // Similarly for B
        private final int bRow;
        private final int bCol;

        private final float[][] C; // Similarly for result matrix C
        private final int cRow;
        private final int cCol;

        private final int size;

        MatrixMultiplyTask(float[][] A, int aRow, int aCol, float[][] B,
                int bRow, int bCol, float[][] C, int cRow, int cCol, int size) {
            this.A = A;
            this.aRow = aRow;
            this.aCol = aCol;
            this.B = B;
            this.bRow = bRow;
            this.bCol = bCol;
            this.C = C;
            this.cRow = cRow;
            this.cCol = cCol;
            this.size = size;
        }

        @Override
        protected void compute() {      
            if (size <= THRESHOLD) {
                multiplyStride2();
            } else {

                int h = size / 2;               

                invokeAll(new MatrixMultiplyTask[] {
                        new MatrixMultiplyTask(A, aRow, aCol, // A11
                                B, bRow, bCol, // B11
                                C, cRow, cCol, // C11
                                h),

                        new MatrixMultiplyTask(A, aRow, aCol + h, // A12
                                B, bRow + h, bCol, // B21
                                C, cRow, cCol, // C11
                                h),

                        new MatrixMultiplyTask(A, aRow, aCol, // A11
                                B, bRow, bCol + h, // B12
                                C, cRow, cCol + h, // C12
                                h),

                        new MatrixMultiplyTask(A, aRow, aCol + h, // A12
                                B, bRow + h, bCol + h, // B22
                                C, cRow, cCol + h, // C12
                                h),

                        new MatrixMultiplyTask(A, aRow + h, aCol, // A21
                                B, bRow, bCol, // B11
                                C, cRow + h, cCol, // C21
                                h),

                        new MatrixMultiplyTask(A, aRow + h, aCol + h, // A22
                                B, bRow + h, bCol, // B21
                                C, cRow + h, cCol, // C21
                                h),

                        new MatrixMultiplyTask(A, aRow + h, aCol, // A21
                                B, bRow, bCol + h, // B12
                                C, cRow + h, cCol + h, // C22
                                h),

                        new MatrixMultiplyTask(A, aRow + h, aCol + h, // A22
                                B, bRow + h, bCol + h, // B22
                                C, cRow + h, cCol + h, // C22
                                h) });

            }
        }

        /**
         * Version of matrix multiplication that steps 2 rows and columns at a
         * time. Adapted from Cilk demos. Note that the results are added into
         * C, not just set into C. This works well here because Java array
         * elements are created with all zero values.
         **/

        void multiplyStride2() {
            for (int j = 0; j < size; j += 2) {
                for (int i = 0; i < size; i += 2) {

                    float[] a0 = A[aRow + i];
                    float[] a1 = A[aRow + i + 1];

                    float s00 = 0.0F;
                    float s01 = 0.0F;
                    float s10 = 0.0F;
                    float s11 = 0.0F;

                    for (int k = 0; k < size; k += 2) {

                        float[] b0 = B[bRow + k];

                        s00 += a0[aCol + k] * b0[bCol + j];
                        s10 += a1[aCol + k] * b0[bCol + j];
                        s01 += a0[aCol + k] * b0[bCol + j + 1];
                        s11 += a1[aCol + k] * b0[bCol + j + 1];

                        float[] b1 = B[bRow + k + 1];

                        s00 += a0[aCol + k + 1] * b1[bCol + j];
                        s10 += a1[aCol + k + 1] * b1[bCol + j];
                        s01 += a0[aCol + k + 1] * b1[bCol + j + 1];
                        s11 += a1[aCol + k + 1] * b1[bCol + j + 1];
                    }

                    C[cRow + i][cCol + j] += s00;
                    C[cRow + i][cCol + j + 1] += s01;
                    C[cRow + i + 1][cCol + j] += s10;
                    C[cRow + i + 1][cCol + j + 1] += s11;
                }
            }
        }
    }
}

有时我的计算无法通​​过检查。 Matrix的某些字段具有预期的不同值。这些不一致是随机的,并不总是发生。我怀疑计算方法出了问题,因为我不得不重写使用Seq类的部分。与invokeAll()方法不同,Seq klass按顺序执行任务。在当前版本的fork / join框架中,该类不再存在。我对矩阵乘法算法不太熟悉,所以很难看出出了什么问题。有什么建议吗?

2 个答案:

答案 0 :(得分:1)

您正在C[cRow + i][cCol + j] += s00;等累积结果。这不是线程安全操作,因此您必须同步行或确保只有一个任务更新单元格。如果没有这个,你会发现随机单元设置不正确。

我会检查你得到正确答案,并发为1.

BTW:float可能不是这里的最佳选择。它具有相当低的精度位数和重型矩阵运算(我假设你正在做或者使用多个线程没有多大意义)舍入误差可能会消耗大部分或全部精度。我建议改为考虑double

e.g。 float有大约7位精度,一条经验法则是误差与计算次数成正比。因此,对于1K x 1K矩阵,您可能还有4位精度。对于10K x 10K,您可能只有三个。 double有16位精度,这意味着在10K x 10K变换后你可能有12位数的精度。

答案 1 :(得分:0)

正如您已经注意到的,顺序执行属于同一象限的子任务对于此算法很重要。因此,您需要实现自己的seq()函数,例如,如下所示,并将其用作原始代码:

public ForkJoinTask<?> seq(final ForkJoinTask<?> a, final ForkJoinTask<?> b) {
    return adapt(new Runnable() {
        public void run() {
            a.invoke();
            b.invoke();
        }
    });
}