我正在编写一个MEX
代码,我需要使用pinv
函数。我试图找到一种方法,以最有效的方式使用double
将pinv
类型的数组传递给mexCallMATLAB
。让我们举例来说,数组名为G
,其大小为100。
double *G = (double*) mxMalloc( 100 * sizeof(double) );
,其中
G[0] = G11; G[1] = G12;
G[2] = G21; G[3] = G22;
这意味着G的每四个连续元素是2×2
矩阵。 G
存储此25
矩阵的2×2
个不同值。
我应该注意到这些2×2
矩阵没有良好的条件,它们的元素中可能包含全零。如何使用pinv
函数计算G
元素中的伪逆?例如,如何将数组传递给mexCallMATLAB
以计算2×2
中第一个G
矩阵的伪逆?
我想到了以下方法:
mxArray *G_PINV_input = mxCreateDoubleMatrix(2, 2, mxREAL);
mxArray *G_PINV_output = mxCreateDoubleMatrix(2, 2, mxREAL);
double *G_PINV_input_ptr = mxGetPr(G_PINV_input);
memcpy( G_PINV_input_ptr, &G[0], 4 * sizeof(double));
mexCallMATLAB(1, G_PINV_output, 1, G_PINV_input, "pinv");
我不确定这种方法有多好。复制值是不经济的,因为实际应用中G
中的元素总数很大。有没有想跳过这个复制?
答案 0 :(得分:2)
这是我对MEX功能的实现:
#include "mex.h"
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
{
// validate arguments
if (nrhs!=1 || nlhs>1)
mexErrMsgIdAndTxt("mex:error", "Wrong number of arguments");
if (!mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) || mxIsSparse(prhs[0]))
mexErrMsgIdAndTxt("mex:error", "Input isnt real dense double array");
if (mxGetNumberOfElements(prhs[0]) != 100)
mexErrMsgIdAndTxt("mex:error", "numel() != 100");
// create necessary arrays
mxArray *rhs[1], *lhs[1];
plhs[0] = mxCreateDoubleMatrix(100, 1, mxREAL);
rhs[0] = mxCreateDoubleMatrix(2, 2, mxREAL);
double *in = mxGetPr(prhs[0]);
double *out = mxGetPr(plhs[0]);
double *x = mxGetPr(rhs[0]), *y;
// for each 2x2 matrix
for (mwIndex i=0; i<100; i+=4) {
// copy 2x2 matrix into rhs
x[0] = in[i+0];
x[2] = in[i+1];
x[1] = in[i+2];
x[3] = in[i+3];
// lhs = pinv(rhs)
mexCallMATLAB(1, lhs, 1, rhs, "pinv");
// copy 2x2 matrix from lhs
y = mxGetPr(lhs[0]);
out[i+0] = y[0];
out[i+1] = y[1];
out[i+2] = y[2];
out[i+3] = y[3];
// free array
mxDestroyArray(lhs[0]);
}
// cleanup
mxDestroyArray(rhs[0]);
}
这是MATLAB中的基线实现,以便我们可以验证结果是否正确:
function y = my_pinv0(x)
y = zeros(size(x));
for i=1:4:numel(x)
y(i:i+3) = pinv(x([0 1; 2 3]+i));
end
end
现在我们测试MEX功能:
% some vector
x = randn(100,1);
% MEX vs. MATLAB function
y = my_pinv0(x);
yy = my_pinv(x);
% compare
assert(isequal(y,yy))
这是另一个实现:
#include "mex.h"
inline void call_pinv(const double &a, const double &b, const double &c,
const double &d, double *out)
{
mxArray *rhs[1], *lhs[1];
// create input matrix [a b; c d]
rhs[0] = mxCreateDoubleMatrix(2, 2, mxREAL);
double *x = mxGetPr(rhs[0]);
x[0] = a;
x[1] = c;
x[2] = b;
x[3] = d;
// lhs = pinv(rhs)
mexCallMATLAB(1, lhs, 1, rhs, "pinv");
// get values from output matrix
const double *y = mxGetPr(lhs[0]);
out[0] = y[0];
out[1] = y[1];
out[2] = y[2];
out[3] = y[3];
// cleanup
mxDestroyArray(lhs[0]);
mxDestroyArray(rhs[0]);
}
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
{
// validate arguments
if (nrhs!=1 || nlhs>1)
mexErrMsgIdAndTxt("mex:error", "Wrong number of arguments");
if (!mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) || mxIsSparse(prhs[0]))
mexErrMsgIdAndTxt("mex:error", "Input isnt real dense double array");
if (mxGetNumberOfElements(prhs[0]) != 100)
mexErrMsgIdAndTxt("mex:error", "numel() != 100");
// allocate output
plhs[0] = mxCreateDoubleMatrix(100, 1, mxREAL);
double *out = mxGetPr(plhs[0]);
const double *in = mxGetPr(prhs[0]);
// for each 2x2 matrix
for (mwIndex i=0; i<100; i+=4) {
// 2x2 input matrix [a b; c d], and its determinant
const double a = in[i+0];
const double b = in[i+1];
const double c = in[i+2];
const double d = in[i+3];
const double det = (a*d - b*c);
if (det != 0) {
// inverse of 2x2 matrix [d -b; -c a]/det
out[i+0] = d/det;
out[i+1] = -c/det;
out[i+2] = -b/det;
out[i+3] = a/det;
}
else {
// singular matrix, fallback to pseudo-inverse
call_pinv(a, b, c, d, &out[i]);
}
}
}
这次我们计算2x2矩阵的行列式,如果非零,我们根据下式计算逆:
否则我们回退到从pseudo-inverse的MATLAB调用PINV。
以下是快速基准:
% 100x1 vector
x = randn(100,1); % average case, with normal 2x2 matrices
% running time
funcs = {@my_pinv0, @my_pinv1, @my_pinv2};
t = cellfun(@(f) timeit(@() f(x)), funcs, 'Uniform',true);
% compare results
y = cellfun(@(f) f(x), funcs, 'Uniform',false);
assert(isequal(y{1},y{2}))
我得到以下时间:
>> fprintf('%.6f\n', t);
0.002111 % MATLAB function
0.001498 % first MEX-file with mexCallMATLAB
0.000010 % second MEX-file with "unrolled" matrix inverse (+ PINV as fallback)
错误是可以接受的并且在机器精度范围内:
>> norm(y{1}-y{3})
ans =
2.1198e-14
当许多2x2矩阵是奇异的时,你也可以测试最坏的情况:
x = randi([0 1], [100 1]);
答案 1 :(得分:0)
您无需分配输出。只需制作指针,让pinv
自动创建mxArray
。
mxArray *lhs;
然后只需使用&
,
mexCallMATLAB(1, &lhs, 1, &rhs, "pinv");