优化MEX文件以加速大型多维数组的乘法(代码瓶颈)

时间:2018-02-24 23:25:04

标签: c optimization compiler-optimization mex

我正在尝试执行诸如将具有3200万个元素的7D阵列相乘的操作。我写了一个MEX文件,因为我认为这些操作在C中比在Matlab中更快。但是,我发现MEX文件的速度大约是在Matlab(2017b)中直接执行操作的两倍。

我想要执行的示例操作是:

T8  = rand(1,1e3,2,2,2,2,2);
wsm = rand(1e3,1e3,2,2);
CM  = bsxfun(@times,T8,wsm);

在我的机器上,这需要0.117065秒(我称之为,其他类似操作,每次运行模型约1000次,模型运行数千次以优化参数 - 这些操作使得优化过于缓慢)。

这是我写的MEX文件,它使用7 for循环通过线性索引访问T8和wsm的元素(也许我应该以更有效的方式访问元素或避免循环?):

#include "mex.h"
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    mwSize i, j, k, l, m, n, o, I, J, K, L, M, N, O;
    mwSize *dims,*dims1;
    double *T8, *wsm, *CM;
      T8  = mxGetPr(prhs[0]);
      wsm = mxGetPr(prhs[1]);

      dims = mxGetDimensions(prhs[0]);
      dims1 = mxGetDimensions(prhs[1]);
      dims[0] = dims1[0];

      I = dims[0];
      J = dims[1];
      K = dims[2];
      L = dims[3];
      M = dims[4];
      N = dims[5];
      O = dims[6];

      plhs[0] = mxCreateNumericArray(7,dims,mxDOUBLE_CLASS,mxREAL);
      CM = mxGetPr(plhs[0]);

      for( o=0; o<O; o++ ) {
          for( n=0; n<N; n++ ) {
              for( m=0; m<M; m++ ) {
                  for( l=0; l<L; l++ ) {
                      for( k=0; k<K; k++ ) {
                          for( j=0; j<J; j++ ) {
                              for( i=0; i<I; i++ ) {
                                  *CM++ = T8[j + k*J + +l*J*K + m*L*J*K + n*M*L*J*K + o*N*M*L*J*K] * wsm[i + j*I + k*I*J + l*I*J*K];
                              }
                          }
                      }
                  }
              }
          }
      }
}

当我调用上述MEX文件时

CM = arrayProduct(T8,wsm);

需要0.215211秒(几乎两倍)。

我的代码非常松散地基于此处建议的代码(https://uk.mathworks.com/matlabcentral/answers/210352-optimize-speed-up-a-big-and-slow-matrix-operation-with-addition-and-bsxfun)。

对于我可以采取哪些不同的方法来加快我的代码的任何建议将不胜感激!

1 个答案:

答案 0 :(得分:1)

假设你可以在这样的平凡矩阵数学中击败Matlab是一个很大的错误。 Matlab从一开始就进行了优化,以执行矩阵数学运算。

有时候有很好的理由来编写MEX函数,包括出于性能原因,但这通常是在纯matlab解决方案无法以最佳方式编写的情况下(例如,当您需要编写批量时)显式循环)。

您的代码可能比Matlab中已经存在的优化矩阵数学慢的两个主要原因是:

  1. Matlab可能会使用多个线程并行执行计算。你的代码没有,但真正的最佳解决方案可能会。
  2. 您可能在内存访问模式中犯了一个错误,导致缓存命中率较低。
  3. 另一种看待这种情况的方法是:如果Matlab无法以最佳方式实现乘法,那么人们是否会将其用于大数据集的严格数学运算?有Matlab不知道的算法,有时可以使用MEX加速,但乘法不是其中之一。