Matlab在计算大型数组的mex函数时花费了太多时间

时间:2016-05-02 07:20:10

标签: arrays matlab optimization

我编写了一个MATLAB脚本,其中我传递了几个标量和一个行向量作为mex函数的输入参数,并且在进行一些计算之后,它返回一个标量作为输出。必须对大小为1 X 1638400的数组的所有元素执行此过程。以下是相应的代码:

ans=0;
for i=0:1638400-1
    temp = sub_imed(r,i,diff);
    ans  = ans + temp*diff(i+1); 
end

其中r,i是标量,diff是大小为1 X 1638400的向量,sub_imed是执行以下工作的MEX函数:

void sub_imed(double r,mwSize base, double* diff, mwSize dim, double* ans)              
{                                                                                           
     mwSize i,k,l,k1,l1;
     double d,g,temp;

     for(i=0; i<dim; i++)
     {   
          k = (base/200) + 1;
          l = (base%200) + 1;
          k1 = (i/200) + 1;
          l1 = (i%200) + 1;

          d = sqrt(pow((k-k1),2) + pow((l-l1),2));

          g=(1/(2*pi*pow(r,2)))*exp(-(pow(d,2))/(2*(pow(r,2))));   

          temp = temp + diff[i]*g;
     }

     *ans  = temp;
}

void mexFunction(int nlhs,mxArray *plhs[],int nrhs,const mxArray *prhs[]) 
{
    double *diff;           /* Input data vectors */
    double r;               /* Value of r (input) */
    double* ans;            /* Output ImED distance */
    size_t base,ncols;      /* For storing the size of input vector and base */

    /* Checking for proper number of arguments */
    if(nrhs!=3) 
       mexErrMsgTxt("Error..Three inputs required.");

    if(nlhs!=1) 
       mexErrMsgTxt("Error..Only one output required.");

    /* make sure the first input argument(value of r) is scalar */
    if( !mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) || mxGetNumberOfElements(prhs[0])!=1) 
       mexErrMsgTxt("Error..Value of r must be a scalar."); 

    /* make sure that the input value of base is a scalar */
    if( !mxIsDouble(prhs[1]) || mxIsComplex(prhs[1]) || mxGetNumberOfElements(prhs[1])!=1) 
       mexErrMsgTxt("Error..Value of base must be a scalar."); 

    /* make sure that the input vector diff is of type double */
    if(!mxIsDouble(prhs[2]) || mxIsComplex(prhs[2]))    
       mexErrMsgTxt("Error..Input vector must be of type double.");

    /* check that number of rows in input arguments is 1 */
    if(mxGetM(prhs[2])!=1) 
       mexErrMsgTxt("Error..Inputs must be row vectors."); 

    /* Get the value of r */
    r = mxGetScalar(prhs[0]);
    base = mxGetScalar(prhs[1]);

    /* Getting the input vectors */
    diff = mxGetPr(prhs[2]);
    ncols = mxGetN(prhs[2]);

    /* Creating link for the scalar output */
    plhs[0] = mxCreateDoubleMatrix(1,1,mxREAL);
    ans = mxGetPr(plhs[0]); 

    sub_imed(r,base,diff,(mwSize)ncols,ans);
}

有关问题和下划线算法的更多详细信息,请按照主题Euclidean distance between images

进行操作

我对我的MATLAB脚本进行了剖析,并了解它需要63秒。只对387次调用sub_imed()mex函数。因此,对于1638400次sub_imed的调用,理想情况下需要大约74小时,这太长了。

有人可以通过建议一些减少计算时间的替代方法来帮助我优化代码。

提前致谢。

2 个答案:

答案 0 :(得分:3)

我将您的代码移植回MATLAB并做了一些小调整,而结果应该保持不变。我介绍了以下常量:

N = 8192;
step = 0.005;

请注意N / step = 1638400。这样,您就可以重写变量k(并将其重命名为baseDiv):

baseDiv = 1 + (0 : step : (N-step)).';

即。它是1:8193,步长为0.005。 同样,l1:200=1:(1/0.005)),连续重复8192次,即(现称为baseMod):

baseMod = (repmat(1:1:(1/step), 1, N)).';

您的变量k1l1只是ik的{​​{1}}元素,即l和{{1} }。

使用向量baseDiv(i)baseMod(i),可以使用

计算baseDivbaseMod和临时变量d
g

我们可以把它放到你的MATLAB for循环中,这样整个程序就变成了

tmp

通过消除内部for循环并以矢量化方式计算它,1000次迭代仍然需要11秒,因此总运行时间为5小时。仍然 - 加速超过 10x 。为了获得更高的加速速度,您有两种可能性:

1)完成矢量化:您可以使用d = sqrt((baseDiv(k)-baseDiv).^2 + (baseMod(k)-baseMod).^2); g = 1/(2*pi*r^2) * exp(-(d.^2) / (2*r^2)); tmp = sum(diffVec .* g); % Constants N = 8192; step = 0.005; % Some example data r = 2; diffVec = rand(N/step,1); base = (0:(numel(diffVec)-1)).'; baseDiv = (1:step:1+N-step).'; baseMod = (repmat(1:1:(1/step), 1, N)).'; res = 0; for k=1:(N/step) d = sqrt((baseDiv(k)-baseDiv).^2 + (baseMod(k)-baseMod).^2); g = 1/(2*pi*r^2) * exp(-(d.^2) / (2*r^2)); tmp = sum(diffVec .* g); res = res + tmp * diffVec(k); end 在列上轻松地对剩余的for循环进行矢量化,以计算 all 值同时。不幸的是我们遇到了一个小问题: 1638400-by-1638400 双矩阵将占用20TB的RAM,我认为 - 你的笔记本电脑没有; - )

2)更少的样本:您正在进行一些分辨率为bsxfun(@minus, baseDiv, baseDiv.')的数学变换。检查你是否真的,真的需要这种精确度!如果你取精度的1/10:sum,你的速度要快100倍,并且在 3分钟内完成

答案 1 :(得分:0)

  • 您是否使用优化标志进行编译? (见旗pow(x,2)
  • 您应该编写x*x
  • ,而不是在C / C ++中使用mexFunction
  • 我几乎可以肯定它之所以变慢是因为你所展示的代码,而是因为你在diff中做了什么。如果我不得不猜测我说你在mex -O myfile.cpp中无意义地复制了内存,但我们需要确保完整的Mex功能。

使用void sub_imed( double r, size_t base, const double *diff, size_t dim, double& ans) { double d, g; // these need to be double to avoid underflow double k = base / 200; double l = base % 200; r = 2*r*r; for(; dim; --dim, ++diff ) { d = k - i/200; g = l - i%200; ans += (*diff) * exp( - (d*d + g*g)/r ) / (pi*r); } } void mexFunction(int nlhs,mxArray *plhs[],int nrhs,const mxArray *prhs[]) { /* Checking for proper number of arguments */ if(nrhs!=3) mexErrMsgTxt("Error..Three inputs required."); if(nlhs!=1) mexErrMsgTxt("Error..Only one output required."); /* make sure the first input argument(value of r) is scalar */ if( !mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) || mxGetNumberOfElements(prhs[0])!=1 ) mexErrMsgTxt("Error..Value of r must be a scalar."); /* make sure that the input value of base is a scalar */ if( !mxIsDouble(prhs[1]) || mxIsComplex(prhs[1]) || mxGetNumberOfElements(prhs[1])!=1 ) mexErrMsgTxt("Error..Value of base must be a scalar."); /* make sure that the input vector diff is of type double */ if( !mxIsDouble(prhs[2]) || mxIsComplex(prhs[2]) ) mexErrMsgTxt("Error..Input vector must be of type double."); /* check that number of rows in input arguments is 1 */ if( mxGetM(prhs[2])!=1 ) mexErrMsgTxt("Error..Inputs must be row vectors."); /* Get the value of r */ double r = mxGetScalar(prhs[0]); size_t base = static_cast<size_t>(mxGetScalar(prhs[1]); /* Getting the input vectors */ const double *diff = mxGetPr(prhs[2]); size_t nrows = static_cast<size_t>(mxGetN(prhs[2])); /* Creating link for the scalar output */ plhs[0] = mxCreateDoubleScalar(0.0); sub_imed( r, base, diff, nrows, *mxGetPr(plhs[0]) ); } 尝试以下C ++代码:

{{1}}