如何用mex代码表示MATLAB的二维数组

时间:2016-03-29 06:47:04

标签: c++ matlab mex

我有一个MATLAB代码

%% Inputs are theta and h (size NxM)
alpha=zeros(N,M);
h_tmp=zeros(N,M);
h_tmp(1:N-1,:)=h(2:N ,:);
for i=1:N
    alpha(i,:)=theta.*(h_tmp(i,:)+h(i,:));
end

通过使用矢量化方法,上面的代码可以是

alpha = theta .* [h(1:N-1,:) + h(2:N,:); h(N,:)];

为了加速代码,我想用C ++在MEX文件中重写它。二维数组中MATLAB和C ++的主要区别是行主顺序(MATLAB)和列主顺序(C ++)

double  *h, *alpha, *h_temp;
int N,M;
double theta;    
N      = (int) mxGetN(prhs[0]); //cols
M      = (int) mxGetM(prhs[0]); //rows
h      = (double *)mxGetData(prhs[0]);
theta  = (double)*mxGetPr(prhs[1]);
/* Initial zeros matrix*/
plhs[0]   = mxCreateDoubleMatrix(M, N, mxREAL);  alpha = mxGetPr(plhs[0]);
//////////////Compute alpha/////////    
for (int rows=0; rows < M; rows++) {
    //h[N*rows+cols] is h_tmp
    for (int cols=0; cols < N; cols++) {        
         alpha[N*rows+cols]=theta*(h[N*rows+cols+1]+h[N*rows+cols]);
    }
}

我的Mex代码和MATLAB代码是否相同?如果没有,你能帮我解决一下吗?

1 个答案:

答案 0 :(得分:1)

除了评论对您的问题的更正外,还有一个小的差别。缺少的是你在Matlab代码中跳过h(N,:),在代码的C代码迭代中完成,直到cols < N,这(由于C中的0索引)也将处理最后一个元素在每一栏中。

#include "mex.h"
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
    double  *h, *alpha, *h_temp;
    int num_columns, num_rows;
    double theta;    
    num_columns = (int) mxGetN(prhs[0]); //cols
    num_rows    = (int) mxGetM(prhs[0]); //rows
    h           = (double *)mxGetData(prhs[0]);
    theta       = (double)*mxGetPr(prhs[1]);
    /* Initial zeros matrix*/
    plhs[0]   = mxCreateDoubleMatrix(num_rows, num_columns, mxREAL);  alpha = mxGetPr(plhs[0]);
    //////////////Compute alpha/////////
    // there are num_rows many elements in each column
    // and num_columns many rows. Matlab stores column first.
    // h[0] ... h[num_rows-1] == h(:,1)
    int idx; // to help make code cleaner
    for (int column_idx=0; column_idx < num_columns; column_idx++) {
        //iterate over each column
        for (int row_idx=0; row_idx < num_rows-1; row_idx++) {// exclude h(end,row_idx)
            //for each row in a column do
            idx = num_columns * column_idx + row_idx;
            alpha[idx]= theta * (h[idx+1] + h[idx]);
        }
    }
    //the last column wasn't modified and set to 0 upon initialization.
    //set it now
    for(int rows = 0; rows < num_rows; rows++) {
        alpha[num_columns*rows+(num_rows-1)] = theta * h[num_columns*rows+(num_rows-1)];
    }
}

请注意,我决定重命名一些变量,这样我觉得它变得更容易阅读。

修改:删除了prhs[0] = plhs[0]的建议,正如此答案的评论中所建议的那样。在某些情况下,人们可能会逃避这种情况,但一般来说,在编写matlab .mex函数时这不是一个好习惯,它可能会使Matlab崩溃。