多线程矩阵乘法

时间:2016-02-15 12:01:18

标签: java multithreading

我最近开始在java中学习多线程。由于我在我的大学编写了一个数值计算程序,我决定通过编程多线程矩阵乘法进行初步尝试。

这是我的代码。请记住,这只是第一次尝试,并不是很干净。

    public class MultithreadingTest{

        public static void main(String[] args) {
            // TODO Auto-generated method stub
            double[][] matrix1 = randomSquareMatrix(2000);
            double[][] matrix2 = randomSquareMatrix(2000);

            matrixMultiplication(matrix1,matrix2,true);
            matrixMultiplicationSingleThread(matrix1, matrix2);
            try {
                matrixMultiplicationParallel(matrix1,matrix2, true);
            } catch (InterruptedException | ExecutionException e) {
                // TODO Auto-generated catch block
                e.printStackTrace();
            }
            try {
                matrixMultiplicationParallel2(matrix1,matrix2, true);
            } catch (InterruptedException | ExecutionException e) {
                // TODO Auto-generated catch block
                e.printStackTrace();
            }

        }

        public static double[][] randomSquareMatrix(int n){
            double[][] mat = new double[n][n];
            Random rand = new Random();
            for(int i=0; i<n; i++) for(int j=0; j<n; j++) mat[i][j]=rand.nextInt(10);
            return mat;
        }
        public static void printSquareMat(double[][] mat){
            int n=mat.length;
            for(int i=0; i<n; i++){ for(int j=0; j<n; j++) System.out.print(mat[i][j]+" "); System.out.print("\n");}
            System.out.print("\n");
        }

        public static void average(double[][] matrix)
        {
            int n=matrix.length;
            double sum=0;
            for(int i=0; i<n; i++) for(int j=0; j<n; j++) sum+=matrix[i][j];

            System.out.println("Average of all Elements of Matrix : "+(sum/(n*n)));
        }

        public static void matrixMultiplication(double[][] matrix1, double[][] matrix2, boolean printMatrix){

            int n=matrix1.length;
            double[][] resultMatrix = new double[n][n];

            double startTime = System.currentTimeMillis();

            for(int i=0; i<n; i++)for(int j=0; j<n; j++)for(int k=0; k<n; k++) resultMatrix[i][j]+=matrix1[i][k]*matrix2[k][j];


            if (printMatrix && n<=5)for(int i=0; i<n; i++){for(int j=0; j<n; j++) System.out.print(resultMatrix[i][j]+" ");System.out.print("\n"); }

            System.out.print("\n");
            System.out.println(((System.currentTimeMillis()-startTime)/1000)+
                    " seconds for matrix of size "+n+" in main thread.");
            average(resultMatrix);
        }

        public static void matrixMultiplicationSingleThread(double[][] m1, double[][] m2)
        {
            int n=m1.length;
            double startTime = System.currentTimeMillis();
            Thread t = new Thread(new multiSingle(m1,m2));
            t.start();
            try {
                t.join();
            } catch (InterruptedException e) {
                // TODO Auto-generated catch block
                System.out.println("Error");
                e.printStackTrace();
            }
            System.out.print("\n");
            System.out.println(((System.currentTimeMillis()-startTime)/1000)+
                    " seconds for matrix of size "+n+" in external Thread.");

        }

        public static void matrixMultiplicationParallel(double[][] matrix1, double[][] matrix2, boolean printMatrix) throws InterruptedException, ExecutionException{

            int n=matrix1.length;
            double[][] resultMatrix=new double[n][n];
            double tmp;
            ExecutorService exe = Executors.newFixedThreadPool(2);
            Future<Double>[][] result = new Future[n][n];
            double startTime = System.currentTimeMillis();
            for(int i=0; i<n; i++)
            {
                for(int j=0; j<=i; j++)
                {
                    tmp=matrix2[i][j];
                    matrix2[i][j]=matrix2[j][i];
                    matrix2[j][i]=tmp;
                }
            }

            for(int i=0; i<n; i++)
            {
                for(int j=0; j<n; j++)
                {
                    result[i][j] = exe.submit(new multi(matrix1[i],matrix2[j]));
                }
            }

            exe.shutdown();
            exe.awaitTermination(1, TimeUnit.DAYS);

            for(int i=0; i<n; i++)
            {
                for(int j=0; j<n; j++)
                {
                    resultMatrix[i][j] = result[i][j].get();
                }
            }
            for(int i=0; i<n; i++)
            {
                for(int j=0; j<=i; j++)
                {
                    tmp=matrix2[i][j];
                    matrix2[i][j]=matrix2[j][i];
                    matrix2[j][i]=tmp;
                }
            }
            if (printMatrix && n<=5)for(int i=0; i<n; i++){for(int j=0; j<n; j++) System.out.print(resultMatrix[i][j]+" ");System.out.print("\n"); }

            System.out.print("\n");
            System.out.println(((System.currentTimeMillis()-startTime)/1000)+
                    " seconds for matrix of size "+n+" multithreaded with algorithm 1.");
            average(resultMatrix);
        }

        public static void matrixMultiplicationParallel2(double[][] matrix1, double[][] matrix2, boolean printMatrix) throws InterruptedException, ExecutionException{

            int n=matrix1.length;
            double[][] resultMatrix=new double[n][n];
            double tmp;
            ExecutorService exe = Executors.newFixedThreadPool(2);
            Future<Double>[][] result = new Future[n][n];
            double startTime = System.currentTimeMillis();


            for(int i=0; i<n; i++)
            {
                for(int j=0; j<n; j++)
                {
                    result[i][j] = exe.submit(new multi2(i,j,matrix1,matrix2));
                }
            }

            exe.shutdown();

            exe.awaitTermination(1, TimeUnit.DAYS);


            for(int i=0; i<n; i++)
            {
                for(int j=0; j<n; j++)
                {
                    resultMatrix[i][j] = result[i][j].get();
                }
            }

            if (printMatrix && n<=5)for(int i=0; i<n; i++){for(int j=0; j<n; j++) System.out.print(resultMatrix[i][j]+" ");System.out.print("\n"); }

            System.out.print("\n");
            System.out.println(((System.currentTimeMillis()-startTime)/1000)+
                    " seconds for matrix of size "+n+" multithreaded with algorithm 2.");
            average(resultMatrix);
        }

        public static class multi implements Callable<Double>{

            multi(double[] vec1, double[] vec2){
                this.vec1=vec1; this.vec2=vec2;
            }
            double result;
            double[] vec1, vec2;

            @Override
            public Double call() {
                result=0;
                for(int i=0; i<vec1.length; i++) result+=vec1[i]*vec2[i];
                return result;
            }
        }

        public static class multi2 implements Callable<Double>{

            multi2(int a, int b, double[][] vec1, double[][] vec2){
                this.a=a; this.b=b; this.vec1=vec1; this.vec2=vec2;
            }
            int a,b;
            double result;
            double[][] vec1, vec2;

            @Override
            public Double call() {
                result=0;
                for(int i=0; i<vec1.length; i++) result+=vec1[a][i]*vec2[i][b];
                return result;
            }
        }

        public static class multiSingle implements Runnable{

            double[][] matrix1, matrix2;

            multiSingle(double[][] m1, double[][] m2){
                matrix1=m1;
                matrix2=m2;
            }
            public static void matrixMultiplication(double[][] matrix1, double[][] matrix2, boolean printMatrix){

                int n=matrix1.length;
                double[][] resultMatrix = new double[n][n];

                for(int i=0; i<n; i++)for(int j=0; j<n; j++)for(int k=0; k<n; k++) resultMatrix[i][j]+=matrix1[i][k]*matrix2[k][j];

                MultithreadingTest.average(resultMatrix);
            }

            @Override
            public void run() {
                matrixMultiplication(matrix1, matrix2, false);
            }
        }

    }

我对多线程有两个一般性的问题,我希望没有为此开一个新主题。

  1. 有没有办法在没有附加类的情况下为实现runnable或callable的线程编写代码?我查看了使用匿名内部类和lambdas的方法,但据我有fount信息,我不能以这种方式将参数传递给线程,因为run()和call()不接受任何,这是,除非参数是最终的。但是假设我为矩阵运算编写了一个类,我宁愿不为我想在一个线程中运行的每个操作编写一个additinal类。
  2. 假设我的类执行了许多多线程操作,创建一个新的线程池并在每个方法中关闭它会浪费大量资源,我想。所以我想创建一个线程池作为我的类的成员,在需要时使用invokeAll实例化它。但如果我的对象被删除会发生什么?我是否会因为我从未关闭线程池而遇到问题?在C ++中,我会使用析构函数。或者gc在这种情况下是否会处理所有事情?
  3. 现在直接隐瞒我的代码:

    我以四种不同的方式实现了矩阵乘法,作为在我的主线程中运行的方法,作为在新线程中运行的方法,但仍然没有多线程(以确保我的主线程中不会有任何后台操作减慢它) ,以及两种不同的多线程矩阵乘法方式。第一个版本转换第二个矩阵,将乘法作为向量 - 向量乘法提交,并将矩阵转换回其原始形式。第二个版本直接采用矩阵,另外还有两个索引来定义矢量矢量乘法矩阵的行和列。

    对于所有版本,我测量了乘法所需的时间,并计算了得到的矩阵的平均值,以查看结果是否相同。

    我在两台计算机上运行此代码,包括相同的JVM和Windows 10.第一台是我的笔记本电脑,i5第5代,2,6 Ghz双核,第二台是我的台式电脑, i5第4代,4,2 Ghz四核。

    我希望我的台式电脑更快。我还期望多线程版本占用了单元线程版本的大约一半/四分之一的时间,但更多的是因为还有额外的工作来创建线程等。最后,我期望第二个多线程版本,它不转置一个矩阵两次,速度更快,因为操作较少。

    运行代码后,我对结果有点困惑,希望有人能向我解释一下:

    对于两种单线程方法,我的笔记本电脑需要大约340s(矩阵大小为3000)。所以我假设在我的主线程中没有完成昂贵的后台任务。另一方面,我的桌面PC需要440s。现在的问题是,为什么我的笔记本电脑速度更快,速度更快?即使第五代比第四代更快,因为我的台式电脑以我的笔记本电脑的1.6倍的速度运行,我仍然期望它更快。这些世代之间的差异不太大。

    对于多线程方法,我的笔记本电脑需要大约34秒。如果多线程是完美的,那么它不应该少于一半。为什么它在两个线程上快十倍?我的台式电脑也是如此。使用四个线程,乘法在16s而不是440s完成。这就像我的桌面PC工作速度与我的笔记本电脑一样,只有四个而不是两个线程。

    现在,对于两个多线程方法之间的比较,两次转换一个矩阵的版本在我的笔记本电脑上大约需要34秒,直接占用矩阵的版本需要大约200秒。这听起来很现实,因为它超过单线程方法的一半。但为什么它比第一个版本慢得多?我会假设两次转置矩阵比获取矩阵元素的额外时间慢?是否有一些我缺失的东西或正在使用矩阵比使用矢量慢得多?

    我希望有人能回答这些问题。很抱歉写这么长的帖子。

    您诚挚的 的Thorsten

2 个答案:

答案 0 :(得分:3)

这个大问题的答案:矩阵乘法所需的时间主要是将数据从RAM移动到CPU缓存所花费的时间。您可能有4个内核,但只有1个RAM总线,因此如果它们都相互阻塞等待内存访问,那么使用更多内核(多线程)将无法获得任何好处。

您应该尝试的第一个实验是:使用矩阵转置和向量乘法编写单线程版本。您会发现它的速度要快得多 - 可能与使用转置的多线程版本一样快。

原始单线程版本之所以如此之慢,是因为它必须为列中的每个单元格加载一个缓存块。如果你使用矩阵转置,那么所有这些单元格在内存中是顺序的,加载一个块会得到一堆。

因此,如果您想优化矩阵乘法,FIRST优化内存访问以提高缓存效率,那么就可以在几个线程之间划分工作 - 不超过核心数量的两倍。通过上下文切换等等,只会浪费时间和资源。

关于您的其他问题:

1)使用lambdas可以方便地从创建它们的范围中捕获变量,例如:

for(int i=0; i<n; i++)
{
    for(int j=0; j<n; j++)
    {
        final double[] v1 = matrix1[i];
        final double[] v2 = matrix2[j];
        result[i][j] = exe.submit(() -> vecdot(v1,v2));
    }
}

2)GC将负责处理。您不需要显式关闭线程池来释放任何资源。

答案 1 :(得分:1)

您必须小心,以尽量减少创建线程的开销。一个好的appraoch是使用ForkJoin框架使用线程池分割问题。这个框架

  • 重用现有的线程池。
  • 分解任务,直到有足够的时间让游泳池保持忙碌但不再有。

每个核心只有一个浮点单元,因此您的可扩展性将基于您拥有的核心数量。

我建议您阅读Fork Join Matrix Multiplication in Java我无法找到此代码的原始来源。

http://gee.cs.oswego.edu/dl/papers/fj.pdf

http://gee.cs.oswego.edu/dl/cpjslides/fj.pdf使用ForkJoin框架。