多线程矩阵乘法

时间:2009-10-15 18:07:11

标签: java multithreading

我编写了一个多线程矩阵乘法。我相信我的方法是正确的,但我不是百分百肯定。关于线程,我不明白为什么我不能只运行(new MatrixThread(...)).start()而不是使用ExecutorService

此外,当我对多线程方法与经典方法进行基准测试时,经典方法很多更快......

我做错了什么?

Matrix Class:

import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

class Matrix
{
   private int dimension;
   private int[][] template;

   public Matrix(int dimension)
   {
      this.template = new int[dimension][dimension];
      this.dimension = template.length;
   }

   public Matrix(int[][] array) 
   {
      this.dimension = array.length;
      this.template = array;      
   }

   public int getMatrixDimension() { return this.dimension; }

   public int[][] getArray() { return this.template; }

   public void fillMatrix()
   {
      Random randomNumber = new Random();
      for(int i = 0; i < dimension; i++)
      {
         for(int j = 0; j < dimension; j++)
         {
            template[i][j] = randomNumber.nextInt(10) + 1;
         }
      }
   }

   @Override
   public String toString()
   {
      String retString = "";
      for(int i = 0; i < this.getMatrixDimension(); i++)
      {
         for(int j = 0; j < this.getMatrixDimension(); j++)
         {
            retString += " " + this.getArray()[i][j];
         }
         retString += "\n";
      }
      return retString;
   }

   public static Matrix classicalMultiplication(Matrix a, Matrix b)
   {      
      int[][] result = new int[a.dimension][b.dimension];
      for(int i = 0; i < a.dimension; i++)
      {
         for(int j = 0; j < b.dimension; j++)
         {
            for(int k = 0; k < b.dimension; k++)
            {
               result[i][j] += a.template[i][k] * b.template[k][j];
            }
         }
      }
      return new Matrix(result);
   }

   public Matrix multiply(Matrix multiplier) throws InterruptedException
   {
      Matrix result = new Matrix(dimension);
      ExecutorService es = Executors.newFixedThreadPool(dimension*dimension);
      for(int currRow = 0; currRow < multiplier.dimension; currRow++)
      {
         for(int currCol = 0; currCol < multiplier.dimension; currCol++)
         {            
            //(new MatrixThread(this, multiplier, currRow, currCol, result)).start();            
            es.execute(new MatrixThread(this, multiplier, currRow, currCol, result));
         }
      }
      es.shutdown();
      es.awaitTermination(2, TimeUnit.DAYS);
      return result;
   }

   private class MatrixThread extends Thread
   {
      private Matrix a, b, result;
      private int row, col;      

      private MatrixThread(Matrix a, Matrix b, int row, int col, Matrix result)
      {         
         this.a = a;
         this.b = b;
         this.row = row;
         this.col = col;
         this.result = result;
      }

      @Override
      public void run()
      {
         int cellResult = 0;
         for (int i = 0; i < a.getMatrixDimension(); i++)
            cellResult += a.template[row][i] * b.template[i][col];

         result.template[row][col] = cellResult;
      }
   }
} 

主要课程:

import java.util.Scanner;

public class MatrixDriver
{
   private static final Scanner kb = new Scanner(System.in);

   public static void main(String[] args) throws InterruptedException
   {      
      Matrix first, second;
      long timeLastChanged,timeNow;
      double elapsedTime;

      System.out.print("Enter value of n (must be a power of 2):");
      int n = kb.nextInt();

      first = new Matrix(n);
      first.fillMatrix();      
      second = new Matrix(n);
      second.fillMatrix();

      timeLastChanged = System.currentTimeMillis();
      //System.out.println("Product of the two using threads:\n" +
                                                        first.multiply(second);
      timeNow = System.currentTimeMillis();
      elapsedTime = (timeNow - timeLastChanged)/1000.0;
      System.out.println("Threaded took "+elapsedTime+" seconds");

      timeLastChanged = System.currentTimeMillis();
      //System.out.println("Product of the two using classical:\n" +
                                  Matrix.classicalMultiplication(first,second);
      timeNow = System.currentTimeMillis();
      elapsedTime = (timeNow - timeLastChanged)/1000.0;
      System.out.println("Classical took "+elapsedTime+" seconds");
   }
} 

P.S。如果需要进一步澄清,请告诉我。

3 个答案:

答案 0 :(得分:6)

即使使用ExecutorService,创建线程也会涉及很多开销。我怀疑为什么你的多线程方法是如此缓慢的原因是你花了99%创建一个新的线程,只有1%或更少,做实际数学。

通常,要解决此问题,您需要将一大堆操作一起批处理并在单个线程上运行它们。在这种情况下,我不是100%如何做到这一点,但我建议将矩阵分成更小的块(比如10个更小的矩阵)并在线程上运行,而不是在自己的线程中运行每个单元。

答案 1 :(得分:5)

你创造了很多线程。创建线程不仅昂贵,而且对于CPU绑定应用程序,您不需要比可用处理器更多的线程(如果这样做,您必须花费线程之间的处理能力切换,这也可能导致缓存错过非常昂贵的)。

也没有必要向[{1}}发送一个帖子;它只需要一个execute。通过应用这些更改,您将获得巨大的性能提升:

  1. 使Runnable成为静态成员,为当前处理器调整大小,然后发送ExecutorService,以便ThreadFactory完成后不会使程序继续运行。 (将它作为参数发送到方法而不是将其保存为静态字段可能在架构上更清晰;我将其留作读者的练习。☺)

    main
  2. private static final ExecutorService workerPool = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors(), new ThreadFactory() { public Thread newThread(Runnable r) { Thread t = new Thread(r); t.setDaemon(true); return t; } }); 实施MatrixThread而不是继承Runnable。线程创建起来很昂贵; POJO非常便宜。您还可以使Thread使实例更小(因为非静态类获得对封闭对象的隐式引用)。

    static
  3. 从更改(1)开始,您不能再private static class MatrixThread implements Runnable 来确保所有任务都已完成(作为此工作池)。而是使用返回awaitTermination的{​​{1}}方法。收集列表中的所有未来对象,当您提交了所有任务时,迭代列表并为每个对象调用submit

  4. 您的Future<?>方法现在看起来应该是这样的:

    get

    它会比单线程版本更快吗?好吧,在我可以说是糟糕的盒子上,多线程版本的值multiply&lt; public Matrix multiply(Matrix multiplier) throws InterruptedException { Matrix result = new Matrix(dimension); List<Future<?>> futures = new ArrayList<Future<?>>(); for(int currRow = 0; currRow < multiplier.dimension; currRow++) { for(int currCol = 0; currCol < multiplier.dimension; currCol++) { Runnable worker = new MatrixThread(this, multiplier, currRow, currCol, result); futures.add(workerPool.submit(worker)); } } for (Future<?> f : futures) { try { f.get(); } catch (ExecutionException e){ throw new RuntimeException(e); // shouldn't happen, but might do } } return result; } 1024.

    但这只是表面上的问题。 真正的问题是你创建了{em>批的n个实例 - 你的内存消耗是MatrixThread,这是一个非常糟糕的标志即可。将内部for循环移动到O(n²)会使性能提高 craploads (理想情况下,您不会创建比工作线程更多的任务)。


    编辑:由于我有更多紧迫的事情要做,我无法抗拒进一步优化。我想出了这个(......极其丑陋的代码片段),“只”创造了MatrixThread.run个工作:

    O(n)

    它仍然不是很好,但基本上多线程版本可以计算你耐心等待的任何东西,并且它比单线程版本更快。

答案 2 :(得分:1)

首先,你应该在你使用的四核上使用你所拥有的核心大小的newFixedThreadPool 4.其次,不要为每个矩阵创建一个新核心。

如果使executorservice成为一个静态成员变量,我几乎可以更快地执行矩阵大小为512的线程版本。

另外,更改MatrixThread以实现Runnable而不是扩展Thread也可以将执行速度加快到我的机器2x上的速度,以及512