在将Double *转换为mxArray时使用mexCallMATLAB的最有效方法*

时间:2014-11-06 00:09:43

标签: matlab mex matrix-inverse

我正在编写一个MEX代码,我需要使用pinv函数。我试图找到一种方法,以最有效的方式使用doublepinv类型的数组传递给mexCallMATLAB。让我们举例来说,数组名为G,其大小为100。

double *G = (double*) mxMalloc( 100 * sizeof(double) );

,其中

G[0] = G11; G[1] = G12;
G[2] = G21; G[3] = G22;

这意味着G的每四个连续元素是2×2矩阵。 G存储此25矩阵的2×2个不同值。

enter image description here

我应该注意到这些2×2矩阵没有良好的条件,它们的元素中可能包含全零。如何使用pinv函数计算G元素中的伪逆?例如,如何将数组传递给mexCallMATLAB以计算2×2中第一个G矩阵的伪逆?

我想到了以下方法:

mxArray *G_PINV_input     = mxCreateDoubleMatrix(2, 2, mxREAL);
mxArray *G_PINV_output    = mxCreateDoubleMatrix(2, 2, mxREAL);
double  *G_PINV_input_ptr = mxGetPr(G_PINV_input);

memcpy( G_PINV_input_ptr, &G[0], 4 * sizeof(double));
mexCallMATLAB(1, G_PINV_output, 1, G_PINV_input, "pinv");

我不确定这种方法有多好。复制值是不经济的,因为实际应用中G中的元素总数很大。有没有想跳过这个复制?

2 个答案:

答案 0 :(得分:2)

这是我对MEX功能的实现:

my_pinv.cpp

#include "mex.h"

void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
{
    // validate arguments
    if (nrhs!=1 || nlhs>1)
        mexErrMsgIdAndTxt("mex:error", "Wrong number of arguments");
    if (!mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) || mxIsSparse(prhs[0]))
        mexErrMsgIdAndTxt("mex:error", "Input isnt real dense double array");
    if (mxGetNumberOfElements(prhs[0]) != 100)
        mexErrMsgIdAndTxt("mex:error", "numel() != 100");

    // create necessary arrays
    mxArray *rhs[1], *lhs[1];
    plhs[0] = mxCreateDoubleMatrix(100, 1, mxREAL);
    rhs[0] = mxCreateDoubleMatrix(2, 2, mxREAL);
    double *in = mxGetPr(prhs[0]);
    double *out = mxGetPr(plhs[0]);
    double *x = mxGetPr(rhs[0]), *y;

    // for each 2x2 matrix
    for (mwIndex i=0; i<100; i+=4) {
        // copy 2x2 matrix into rhs
        x[0] = in[i+0];
        x[2] = in[i+1];
        x[1] = in[i+2];
        x[3] = in[i+3];

        // lhs = pinv(rhs)
        mexCallMATLAB(1, lhs, 1, rhs, "pinv");

        // copy 2x2 matrix from lhs
        y = mxGetPr(lhs[0]);
        out[i+0] = y[0];
        out[i+1] = y[1];
        out[i+2] = y[2];
        out[i+3] = y[3];

        // free array
        mxDestroyArray(lhs[0]);
    }

    // cleanup
    mxDestroyArray(rhs[0]);
}

这是MATLAB中的基线实现,以便我们可以验证结果是否正确:

my_pinv0.m

function y = my_pinv0(x)
    y = zeros(size(x));
    for i=1:4:numel(x)
        y(i:i+3) = pinv(x([0 1; 2 3]+i));
    end
end

现在我们测试MEX功能:

% some vector
x = randn(100,1);

% MEX vs. MATLAB function
y = my_pinv0(x);
yy = my_pinv(x);

% compare
assert(isequal(y,yy))

编辑:

这是另一个实现:

my_pinv2.cpp

#include "mex.h"

inline void call_pinv(const double &a, const double &b, const double &c,
                      const double &d, double *out)
{
    mxArray *rhs[1], *lhs[1];

    // create input matrix [a b; c d]
    rhs[0] = mxCreateDoubleMatrix(2, 2, mxREAL);
    double *x = mxGetPr(rhs[0]);
    x[0] = a;
    x[1] = c;
    x[2] = b;
    x[3] = d;

    // lhs = pinv(rhs)
    mexCallMATLAB(1, lhs, 1, rhs, "pinv");

    // get values from output matrix
    const double *y = mxGetPr(lhs[0]);
    out[0] = y[0];
    out[1] = y[1];
    out[2] = y[2];
    out[3] = y[3];

    // cleanup
    mxDestroyArray(lhs[0]);
    mxDestroyArray(rhs[0]);
}

void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
{
    // validate arguments
    if (nrhs!=1 || nlhs>1)
        mexErrMsgIdAndTxt("mex:error", "Wrong number of arguments");
    if (!mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) || mxIsSparse(prhs[0]))
        mexErrMsgIdAndTxt("mex:error", "Input isnt real dense double array");
    if (mxGetNumberOfElements(prhs[0]) != 100)
        mexErrMsgIdAndTxt("mex:error", "numel() != 100");

    // allocate output
    plhs[0] = mxCreateDoubleMatrix(100, 1, mxREAL);
    double *out = mxGetPr(plhs[0]);
    const double *in = mxGetPr(prhs[0]);

    // for each 2x2 matrix
    for (mwIndex i=0; i<100; i+=4) {
        // 2x2 input matrix [a b; c d], and its determinant
        const double a = in[i+0];
        const double b = in[i+1];
        const double c = in[i+2];
        const double d = in[i+3];
        const double det = (a*d - b*c);

        if (det != 0) {
            // inverse of 2x2 matrix [d -b; -c a]/det
            out[i+0] =  d/det;
            out[i+1] = -c/det;
            out[i+2] = -b/det;
            out[i+3] =  a/det;
        }
        else {
            // singular matrix, fallback to pseudo-inverse
            call_pinv(a, b, c, d, &out[i]);
        }
    }
}

这次我们计算2x2矩阵的行列式,如果非零,我们根据下式计算逆:

2x2_matrix_inverse

否则我们回退到从pseudo-inverse的MATLAB调用PINV。

以下是快速基准:

% 100x1 vector
x = randn(100,1);           % average case, with normal 2x2 matrices

% running time
funcs = {@my_pinv0, @my_pinv1, @my_pinv2};
t = cellfun(@(f) timeit(@() f(x)), funcs, 'Uniform',true);

% compare results
y = cellfun(@(f) f(x), funcs, 'Uniform',false);
assert(isequal(y{1},y{2}))

我得到以下时间:

>> fprintf('%.6f\n', t);
0.002111   % MATLAB function
0.001498   % first MEX-file with mexCallMATLAB
0.000010   % second MEX-file with "unrolled" matrix inverse (+ PINV as fallback)

错误是可以接受的并且在机器精度范围内:

>> norm(y{1}-y{3})
ans =
   2.1198e-14

当许多2x2矩阵是奇异的时,你也可以测试最坏的情况:

x = randi([0 1], [100 1]);

答案 1 :(得分:0)

您无需分配输出。只需制作指针,让pinv自动创建mxArray

mxArray *lhs;

然后只需使用&

mexCallMATLAB(1, &lhs, 1, &rhs, "pinv");