
时间:2018-11-01 02:48:30

标签: java multithreading join fork matrix-multiplication

我有一项作业,要求运用我所学的 fork/join 知识。我听说有传言称,对于这个问题,在 n = 11 时可以实现 60 倍的加速。我只能想到,该问题是基于 Matrix 中提供的递归算法(将乘法分解为 4 个较小的子矩阵),因此我只是复制了该算法,并将其改写为 MatrixMultiplication 的 compute 方法。但是,在 n = 11 时我只实现了 20 倍的加速。我不知道无法达到预期的加速是与我使用的算法有关,还是与 fork 和 join 的调用位置有关。当我把某些子任务的调用从 fork/join 改为直接调用 compute 时,也没有看到任何改进。有人能给我一点提示,告诉我如何才能实现 60 倍的加速吗?


enter image description here

enter image description here




import java.util.concurrent.RecursiveTask;

class MatrixMultiplication extends RecursiveTask<Matrix> {

  /** The fork threshold. */
  private static final int FORK_THRESHOLD = 128;

  /** The first matrix to multiply with. */
  private Matrix m1;

  /** The second matrix to multiply with. */
  private Matrix m2;

  /** The starting row of m1. */
  private int m1Row;

  /** The starting col of m1. */
  private int m1Col;

  /** The starting row of m2. */
  private int m2Row;

  /** The starting col of m2. */
  private int m2Col;

   * The dimension of the input (sub)-matrices and the size of the output
   * matrix.
  private int dimension;

   * A constructor for the Matrix Multiplication class.
   * @param  m1 The matrix to multiply with.
   * @param  m2 The matrix to multiply with.
   * @param  m1Row The starting row of m1.
   * @param  m1Col The starting col of m1.
   * @param  m2Row The starting row of m2.
   * @param  m2Col The starting col of m2.
   * @param  dimension The dimension of the input (sub)-matrices and the size
   *     of the output matrix.
  MatrixMultiplication(Matrix m1, Matrix m2, int m1Row, int m1Col, int m2Row,
                       int m2Col, int dimension) {
    this.m1 = m1;
    this.m2 = m2;
    this.m1Row = m1Row;
    this.m1Col = m1Col;
    this.m2Row = m2Row;
    this.m2Col = m2Col;
    this.dimension = dimension;

  public Matrix compute() {
   /* if (dimension == 1) {
      //Matrix result = new Matrix(1);
      return Matrix.nonRecursiveMultiply(m1,m2,m1Row,m1Col,m2Row,m2Col,dimension);
      //trivial and return the same type...

      //count the usually we do one and must be pre-existing

    if (dimension < FORK_THRESHOLD) {
      return Matrix.nonRecursiveMultiply(m1, m2, m1Row, m1Col, m2Row, m2Col, dimension);

    int size = dimension / 2;
    Matrix result = new Matrix(dimension); 

    MatrixMultiplication mma11b11 = new MatrixMultiplication(m1,m2,m1Row,m1Col,
      m2Row,m2Col, size);

    MatrixMultiplication mma12b21 = new MatrixMultiplication(m1, m2, m1Row, 
      m1Col + size, m2Row + size, m2Col, size);

    MatrixMultiplication mma11b12 = new MatrixMultiplication(m1, m2, m1Row, m1Col, 
      m2Row, m2Col + size, size);

    MatrixMultiplication mma12b22 = new MatrixMultiplication(m1, m2, m1Row, 
      m1Col + size, m2Row + size, m2Col + size, size);

    MatrixMultiplication mma21b11 = new MatrixMultiplication(m1, m2, m1Row + size, 
      m1Col, m2Row, m2Col, size);

    MatrixMultiplication mma22b21 = new MatrixMultiplication(m1, m2, m1Row + size, 
      m1Col + size, m2Row + size, m2Col, size);

    MatrixMultiplication mma21b12 = new MatrixMultiplication(m1, m2, m1Row + size, 
      m1Col, m2Row, m2Col + size, size);

    MatrixMultiplication mma22b22 = new MatrixMultiplication(m1, m2, m1Row + size, 
      m1Col + size, m2Row + size, m2Col + size, size);


    Matrix a22b22 = mma22b22.join();
    Matrix a21b12 = mma21b12.join();
    Matrix a22b21 = mma22b21.join();
    Matrix a21b11 = mma21b11.join();
    Matrix a12b22 = mma12b22.join();
    Matrix a11b12 = mma11b12.join();
    Matrix a12b21 = mma12b21.join();
    Matrix a11b11 = mma11b11.join();

    for (int i = 0; i < size; i++) {
      double[] m1m = a21b12.m[i];
      double[] m2m = a22b22.m[i];
      double[] r1m = result.m[i + size];
      for (int j = 0; j < size; j++) {
        r1m[j + size] = m1m[j] + m2m[j];

    for (int i = 0; i < size; i++) {
      double[] m1m = a21b11.m[i];
      double[] m2m = a22b21.m[i];
      double[] r1m = result.m[i + size];
      for (int j = 0; j < size; j++) {
        r1m[j] = m1m[j] + m2m[j];

    for (int i = 0; i < size; i++) {
      double[] m1m = a11b12.m[i];
      double[] m2m = a12b22.m[i];
      double[] r1m = result.m[i];
      for (int j = 0; j < size; j++) {
        r1m[j + size] = m1m[j] + m2m[j];

    for (int i = 0; i < size; i++) {
      double[] m1m = a11b11.m[i];
      double[] m2m = a12b21.m[i];
      double[] r1m = result.m[i];
      for (int j = 0; j < size; j++) {
        r1m[j] = m1m[j] + m2m[j];

    return result; 


import java.util.function.Supplier;
import java.lang.StringBuilder;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.ForkJoinTask;
import java.lang.Runtime;

/** Encapsulate a square matrix of double values. */
class Matrix {
   * 2D square array of double values, storing the matrix.
  double[][] m;
   * The number of columns and rows in the matrix.
  int dimension;

  private static final int THRESHOLD = 2;

   * Checks if two matrices are equals.
   * @param   m1  First matrices to check
   * @param   m2  Second matrices to check against
   * @return  true if every elements in m1 and m2 are the same; false otherwise.
  public static boolean equals(Matrix m1, Matrix m2) {
    if (m1.dimension != m2.dimension) {
      return false;
    for (int i = 0; i < m1.dimension; i++) {
      for (int j = 0; j < m1.dimension; j++) {
        if (Math.abs(m1.m[i][j] - m2.m[i][j]) > 0.000001) {
          return false;
    return true;

   * A constructor for the matrix.
   * @param  dimension The number of rows.
  Matrix(int dimension) {
    this.dimension = dimension;
    this.m = new double[dimension][dimension];

   * Generate a matrix of d x d according to the given supplier.
   * @param  dimension The dimension of the matrix
   * @param  supplier The lambda to generate the matrix with.
   * @return The new matrix.
  static Matrix generate(int dimension, Supplier<Double> supplier) {
    Matrix matrix = new Matrix(dimension);
    for (int row = 0; row < dimension; row++) {
      for (int col = 0; col < dimension; col++) {
        matrix.m[row][col] = supplier.get();
    return matrix;

   * Return a string representation of the matrix, pretty-printed
   * with each row on a single line.
   * @return The string representation of this matrix.
  public String toString() {
    StringBuilder s = new StringBuilder();
    for (int row = 0; row < dimension; row++) {
      for (int col = 0; col < dimension; col++) {
        s.append(String.format("%.4f", m[row][col]) + " ");
    return s.toString();

   * Multiply matrix m with this matrix, return a new result matrix.
   * @param  m1 The matrix to multiply with.
   * @param  m2 The matrix to multiply with.
   * @param  m1Row The starting row of m1.
   * @param  m1Col The starting col of m1.
   * @param  m2Row The starting row of m2.
   * @param  m2Col The starting col of m2.
   * @param  dimension The dimension of the input (sub)-matrices and the size
   *     of the output matrix.
   * @return The new matrix.
  public static Matrix nonRecursiveMultiply(Matrix m1, Matrix m2,
      int m1Row, int m1Col, int m2Row, int m2Col, int dimension) {
    Matrix result = new Matrix(dimension);
    for (int row = 0; row < dimension; row++) {
      for (int col = 0; col < dimension; col++) {
        double sum = 0;
        // multiply row to col
        for (int i = 0; i < dimension; i++) {
          sum += m1.m[row + m1Row][i + m1Col] * m2.m[i + m2Row][col + m2Col];
        result.m[row][col] = sum;
    return result;

   * Multiple two matrices non-recursively.
   * @param m1 The matrix to multiply with.
   * @param m2 The matrix to multiply with.
   * @return The resulting matrix m1 * m2
  public static Matrix nonRecursiveMultiply(Matrix m1, Matrix m2) {
    return Matrix.nonRecursiveMultiply(m1, m2, 0, 0, 0, 0, m1.dimension);

   * Multiply matrix m with this matrix, return a new result matrix.
   * @param  m1 The matrix to multiply with.
   * @param  m2 The matrix to multiply with.
   * @param  m1Row The starting row of m1.
   * @param  m1Col The starting col of m1.
   * @param  m2Row The starting row of m2.
   * @param  m2Col The starting col of m2.
   * @param  dimension The dimension of the input (sub)-matrices and the size
   *     of the output matrix.
   * @return The resulting matrix m1 * m2
  public static Matrix recursiveMultiply(Matrix m1, Matrix m2,
      int m1Row, int m1Col, int m2Row, int m2Col, int dimension) {

    // If the matrix is small enough, just multiple non-recursively.
    if (dimension <= THRESHOLD) {
      return Matrix.nonRecursiveMultiply(m1, m2, m1Row, m1Col, m2Row, m2Col, dimension);

    // Else, cut the matrix into four blocks of equal size, recursively
    // multiply then sum the multiplication result.
    int size = dimension / 2;
    Matrix result = new Matrix(dimension);
    Matrix a11b11 = recursiveMultiply(m1, m2, m1Row, m1Col, m2Row,
        m2Col, size);
    Matrix a12b21 = recursiveMultiply(m1, m2, m1Row, m1Col + size,
        m2Row + size, m2Col, size);
    for (int i = 0; i < size; i++) {
      double[] m1m = a11b11.m[i];
      double[] m2m = a12b21.m[i];
      double[] r1m = result.m[i];
      for (int j = 0; j < size; j++) {
        r1m[j] = m1m[j] + m2m[j];

    Matrix a11b12 = recursiveMultiply(m1, m2, m1Row, m1Col, m2Row,
        m2Col + size, size);
    Matrix a12b22 = recursiveMultiply(m1, m2, m1Row, m1Col + size,
        m2Row + size, m2Col + size, size);
    for (int i = 0; i < size; i++) {
      double[] m1m = a11b12.m[i];
      double[] m2m = a12b22.m[i];
      double[] r1m = result.m[i];
      for (int j = 0; j < size; j++) {
        r1m[j + size] = m1m[j] + m2m[j];

    Matrix a21b11 = recursiveMultiply(m1, m2, m1Row + size, m1Col,
        m2Row, m2Col, size);
    Matrix a22b21 = recursiveMultiply(m1, m2, m1Row + size, m1Col + size,
        m2Row + size, m2Col, size);
    for (int i = 0; i < size; i++) {
      double[] m1m = a21b11.m[i];
      double[] m2m = a22b21.m[i];
      double[] r1m = result.m[i + size];
      for (int j = 0; j < size; j++) {
        r1m[j] = m1m[j] + m2m[j];

    Matrix a21b12 = recursiveMultiply(m1, m2, m1Row + size, m1Col,
        m2Row, m2Col + size, size);
    Matrix a22b22 = recursiveMultiply(m1, m2, m1Row + size, m1Col + size,
        m2Row + size, m2Col + size, size);
    for (int i = 0; i < size; i++) {
      double[] m1m = a21b12.m[i];
      double[] m2m = a22b22.m[i];
      double[] r1m = result.m[i + size];
      for (int j = 0; j < size; j++) {
        r1m[j + size] = m1m[j] + m2m[j];
    return result;

   * Multiple two matrices recursively but sequentially with
   * divide-and-conquer algorithm.
   * @param m1 The matrix to multiply with.
   * @param m2 The matrix to multiply with.
   * @return The resulting matrix m1 * m2
  public static Matrix recursiveMultiply(Matrix m1, Matrix m2) {
    return Matrix.recursiveMultiply(m1, m2, 0, 0, 0, 0, m1.dimension);

   * Multiple two matrices recursively and parallely with
   * divide-and-conquer algorithm.
   * @param m1 The matrix to multiply with.
   * @param m2 The matrix to multiply with.
   * @return The resulting matrix m1 * m2
  public static Matrix parallelMultiply(Matrix m1, Matrix m2) {
    return new MatrixMultiplication(m1, m2, 0, 0, 0, 0, m1.dimension)

import java.util.function.Supplier;
import java.util.stream.DoubleStream;
import java.util.Random;
import java.util.Scanner;
import java.time.Instant;
import java.time.Duration;

编辑:在程序的输出中添加到/** * Main is the main driver class for testing matrix multiplication. * Usage: java Main n * 2^n is the dimension of the square matrixOne */ class Main { public static void main(String[] args) { int n = (new Scanner(System.in)).nextInt(); Random random = new Random(1); int dimension = 1 << n; System.out.println("dimension " + dimension); Matrix matrixOne = Matrix.generate(dimension, () -> random.nextDouble()); Matrix matrixTwo = Matrix.generate(dimension, () -> random.nextDouble()); Matrix result1 = Matrix.nonRecursiveMultiply(matrixOne, matrixTwo); Matrix result2 = Matrix.parallelMultiply(matrixOne, matrixTwo); boolean match = Matrix.equals(result1, result2); if (!match) { System.out.println("ERROR: matrix multiplication gives inconsistent " + "result in sequential and parallel implementations."); return; } double d1 = measureTimeToRun(() -> Matrix.nonRecursiveMultiply(matrixOne, matrixTwo)); double d2 = measureTimeToRun(() -> Matrix.parallelMultiply(matrixOne, matrixTwo)); System.out.printf("Parallel %.3f ms Sequential %.3f ms Speedup %.3f times\n", d2, d1, d1 / d2); } /** * Return the average time needed to run the task over three runs. * @param task A lambda expression for the task to be run * @return The average time taken in ms. */ private static double measureTimeToRun(Supplier<Matrix> task) { final int numOfTimes = 3; double sum = 0; for (int i = 0; i < numOfTimes; i++) { Instant start = Instant.now(); Matrix m = task.get(); Instant stop = Instant.now(); sum += Duration.between(start, stop).toMillis(); } return sum / numOfTimes; } } 的输出中,以防引起混乱。

0 个答案:
