我有一个MATLAB代码
%% Inputs are theta and h (size NxM)
alpha=zeros(N,M);
h_tmp=zeros(N,M);
h_tmp(1:N-1,:)=h(2:N ,:);
for i=1:N
alpha(i,:)=theta.*(h_tmp(i,:)+h(i,:));
end
通过使用矢量化方法,上面的代码可以是
alpha = theta .* [h(1:N-1,:) + h(2:N,:); h(N,:)];
为了加速代码,我想用C ++在MEX文件中重写它。二维数组中MATLAB和C ++的主要区别是行主顺序(MATLAB)和列主顺序(C ++)
double *h, *alpha, *h_temp;
int N,M;
double theta;
N = (int) mxGetN(prhs[0]); //cols
M = (int) mxGetM(prhs[0]); //rows
h = (double *)mxGetData(prhs[0]);
theta = (double)*mxGetPr(prhs[1]);
/* Initial zeros matrix*/
plhs[0] = mxCreateDoubleMatrix(M, N, mxREAL); alpha = mxGetPr(plhs[0]);
//////////////Compute alpha/////////
for (int rows=0; rows < M; rows++) {
//h[N*rows+cols] is h_tmp
for (int cols=0; cols < N; cols++) {
alpha[N*rows+cols]=theta*(h[N*rows+cols+1]+h[N*rows+cols]);
}
}
我的Mex代码和MATLAB代码是否相同?如果没有,你能帮我解决一下吗?
答案 0 :(得分:1)
除了评论对您的问题的更正外,还有一个小的差别。缺少的是你在Matlab代码中跳过h(N,:)
,在代码的C代码迭代中完成,直到cols < N
,这(由于C中的0索引)也将处理最后一个元素在每一栏中。
#include "mex.h"
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
double *h, *alpha, *h_temp;
int num_columns, num_rows;
double theta;
num_columns = (int) mxGetN(prhs[0]); //cols
num_rows = (int) mxGetM(prhs[0]); //rows
h = (double *)mxGetData(prhs[0]);
theta = (double)*mxGetPr(prhs[1]);
/* Initial zeros matrix*/
plhs[0] = mxCreateDoubleMatrix(num_rows, num_columns, mxREAL); alpha = mxGetPr(plhs[0]);
//////////////Compute alpha/////////
// there are num_rows many elements in each column
// and num_columns many rows. Matlab stores column first.
// h[0] ... h[num_rows-1] == h(:,1)
int idx; // to help make code cleaner
for (int column_idx=0; column_idx < num_columns; column_idx++) {
//iterate over each column
for (int row_idx=0; row_idx < num_rows-1; row_idx++) {// exclude h(end,row_idx)
//for each row in a column do
idx = num_columns * column_idx + row_idx;
alpha[idx]= theta * (h[idx+1] + h[idx]);
}
}
//the last column wasn't modified and set to 0 upon initialization.
//set it now
for(int rows = 0; rows < num_rows; rows++) {
alpha[num_columns*rows+(num_rows-1)] = theta * h[num_columns*rows+(num_rows-1)];
}
}
请注意,我决定重命名一些变量,这样我觉得它变得更容易阅读。
修改:删除了prhs[0] = plhs[0]
的建议,正如此答案的评论中所建议的那样。在某些情况下,人们可能会逃避这种情况,但一般来说,在编写matlab .mex函数时这不是一个好习惯,它可能会使Matlab崩溃。