我编写了一个MATLAB脚本,其中我传递了几个标量和一个行向量作为mex函数的输入参数,并且在进行一些计算之后,它返回一个标量作为输出。必须对大小为1 X 1638400的数组的所有元素执行此过程。以下是相应的代码:
ans=0;
for i=0:1638400-1
temp = sub_imed(r,i,diff);
ans = ans + temp*diff(i+1);
end
其中r,i是标量,diff是大小为1 X 1638400的向量,sub_imed是执行以下工作的MEX函数:
void sub_imed(double r,mwSize base, double* diff, mwSize dim, double* ans)
{
mwSize i,k,l,k1,l1;
double d,g,temp;
for(i=0; i<dim; i++)
{
k = (base/200) + 1;
l = (base%200) + 1;
k1 = (i/200) + 1;
l1 = (i%200) + 1;
d = sqrt(pow((k-k1),2) + pow((l-l1),2));
g=(1/(2*pi*pow(r,2)))*exp(-(pow(d,2))/(2*(pow(r,2))));
temp = temp + diff[i]*g;
}
*ans = temp;
}
void mexFunction(int nlhs,mxArray *plhs[],int nrhs,const mxArray *prhs[])
{
double *diff; /* Input data vectors */
double r; /* Value of r (input) */
double* ans; /* Output ImED distance */
size_t base,ncols; /* For storing the size of input vector and base */
/* Checking for proper number of arguments */
if(nrhs!=3)
mexErrMsgTxt("Error..Three inputs required.");
if(nlhs!=1)
mexErrMsgTxt("Error..Only one output required.");
/* make sure the first input argument(value of r) is scalar */
if( !mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) || mxGetNumberOfElements(prhs[0])!=1)
mexErrMsgTxt("Error..Value of r must be a scalar.");
/* make sure that the input value of base is a scalar */
if( !mxIsDouble(prhs[1]) || mxIsComplex(prhs[1]) || mxGetNumberOfElements(prhs[1])!=1)
mexErrMsgTxt("Error..Value of base must be a scalar.");
/* make sure that the input vector diff is of type double */
if(!mxIsDouble(prhs[2]) || mxIsComplex(prhs[2]))
mexErrMsgTxt("Error..Input vector must be of type double.");
/* check that number of rows in input arguments is 1 */
if(mxGetM(prhs[2])!=1)
mexErrMsgTxt("Error..Inputs must be row vectors.");
/* Get the value of r */
r = mxGetScalar(prhs[0]);
base = mxGetScalar(prhs[1]);
/* Getting the input vectors */
diff = mxGetPr(prhs[2]);
ncols = mxGetN(prhs[2]);
/* Creating link for the scalar output */
plhs[0] = mxCreateDoubleMatrix(1,1,mxREAL);
ans = mxGetPr(plhs[0]);
sub_imed(r,base,diff,(mwSize)ncols,ans);
}
有关问题和下划线算法的更多详细信息,请按照主题Euclidean distance between images。
进行操作我对我的MATLAB脚本进行了剖析,并了解它需要63秒。只对387次调用sub_imed()mex函数。因此,对于1638400次sub_imed的调用,理想情况下需要大约74小时,这太长了。
有人可以通过建议一些减少计算时间的替代方法来帮助我优化代码。
提前致谢。
答案 0 :(得分:3)
我将您的代码移植回MATLAB并做了一些小调整,而结果应该保持不变。我介绍了以下常量:
N = 8192;
step = 0.005;
请注意N / step = 1638400
。这样,您就可以重写变量k
(并将其重命名为baseDiv
):
baseDiv = 1 + (0 : step : (N-step)).';
即。它是1:8193
,步长为0.005
。
同样,l
为1:200
(=1:(1/0.005)
),连续重复8192次,即(现称为baseMod
):
baseMod = (repmat(1:1:(1/step), 1, N)).';
您的变量k1
和l1
只是i
和k
的{{1}}元素,即l
和{{1} }。
使用向量baseDiv(i)
和baseMod(i)
,可以使用
baseDiv
,baseMod
和临时变量d
g
我们可以把它放到你的MATLAB for循环中,这样整个程序就变成了
tmp
通过消除内部for循环并以矢量化方式计算它,1000次迭代仍然需要11秒,因此总运行时间为5小时。仍然 - 加速超过 10x 。为了获得更高的加速速度,您有两种可能性:
1)完成矢量化:您可以使用d = sqrt((baseDiv(k)-baseDiv).^2 + (baseMod(k)-baseMod).^2);
g = 1/(2*pi*r^2) * exp(-(d.^2) / (2*r^2));
tmp = sum(diffVec .* g);
和% Constants
N = 8192;
step = 0.005;
% Some example data
r = 2;
diffVec = rand(N/step,1);
base = (0:(numel(diffVec)-1)).';
baseDiv = (1:step:1+N-step).';
baseMod = (repmat(1:1:(1/step), 1, N)).';
res = 0;
for k=1:(N/step)
d = sqrt((baseDiv(k)-baseDiv).^2 + (baseMod(k)-baseMod).^2);
g = 1/(2*pi*r^2) * exp(-(d.^2) / (2*r^2));
tmp = sum(diffVec .* g);
res = res + tmp * diffVec(k);
end
在列上轻松地对剩余的for循环进行矢量化,以计算 all 值同时。不幸的是我们遇到了一个小问题: 1638400-by-1638400 双矩阵将占用20TB的RAM,我认为 - 你的笔记本电脑没有; - )
2)更少的样本:您正在进行一些分辨率为bsxfun(@minus, baseDiv, baseDiv.')
的数学变换。检查你是否真的,真的需要这种精确度!如果你取精度的1/10:sum
,你的速度要快100倍,并且在 3分钟内完成!
答案 1 :(得分:0)
pow(x,2)
)x*x
mexFunction
diff
中做了什么。如果我不得不猜测我说你在mex -O myfile.cpp
中无意义地复制了内存,但我们需要确保完整的Mex功能。使用void sub_imed( double r, size_t base, const double *diff, size_t dim, double& ans)
{
double d, g;
// these need to be double to avoid underflow
double k = base / 200;
double l = base % 200;
r = 2*r*r;
for(; dim; --dim, ++diff )
{
d = k - i/200;
g = l - i%200;
ans += (*diff) * exp( - (d*d + g*g)/r ) / (pi*r);
}
}
void mexFunction(int nlhs,mxArray *plhs[],int nrhs,const mxArray *prhs[])
{
/* Checking for proper number of arguments */
if(nrhs!=3)
mexErrMsgTxt("Error..Three inputs required.");
if(nlhs!=1)
mexErrMsgTxt("Error..Only one output required.");
/* make sure the first input argument(value of r) is scalar */
if( !mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) || mxGetNumberOfElements(prhs[0])!=1 )
mexErrMsgTxt("Error..Value of r must be a scalar.");
/* make sure that the input value of base is a scalar */
if( !mxIsDouble(prhs[1]) || mxIsComplex(prhs[1]) || mxGetNumberOfElements(prhs[1])!=1 )
mexErrMsgTxt("Error..Value of base must be a scalar.");
/* make sure that the input vector diff is of type double */
if( !mxIsDouble(prhs[2]) || mxIsComplex(prhs[2]) )
mexErrMsgTxt("Error..Input vector must be of type double.");
/* check that number of rows in input arguments is 1 */
if( mxGetM(prhs[2])!=1 )
mexErrMsgTxt("Error..Inputs must be row vectors.");
/* Get the value of r */
double r = mxGetScalar(prhs[0]);
size_t base = static_cast<size_t>(mxGetScalar(prhs[1]);
/* Getting the input vectors */
const double *diff = mxGetPr(prhs[2]);
size_t nrows = static_cast<size_t>(mxGetN(prhs[2]));
/* Creating link for the scalar output */
plhs[0] = mxCreateDoubleScalar(0.0);
sub_imed( r, base, diff, nrows, *mxGetPr(plhs[0]) );
}
尝试以下C ++代码:
{{1}}