使用bsxfun加速Matlab嵌套for循环

时间:2016-04-18 12:43:50

标签: matlab optimization vectorization bsxfun

我将图n x nW描述为其邻接矩阵和每个节点的组标签(整数)的n向量。

对于每对组,我需要计算组c中的节点与组d中的节点之间的链接(边)数。为此我写了一个嵌套的for loop,但我确信这不是计算代码的最快方法,我在代码中调用mcd,即计算边数之间的矩阵的矩阵小组cd。 是否可以通过bsxfun更快地完成此操作?

function mcd = interlinks(W,ci)
%// W is the adjacency matrix of a simple undirected graph
%// ci are the group labels of every node in the graph, can be from 1 to |C|
n = length(W); %// number of nodes in the graph
m = sum(nonzeros(triu(W))); %// number of edges in the graph
ncomms = length(unique(ci)); %// number of groups of ci

mcd = zeros(ncomms); %// this is the matrix that counts the number of edges between group c and group d, twice the number of it if c==d

for c=1:ncomms
    nodesc = find(ci==c); %// nodes in group c
    for d=1:ncomms
        nodesd = find(ci==d); %// nodes in group d
        M = W(nodesc,nodesd); %// submatrix of edges between c and d
        mcd(c,d) = sum(sum(M)); %// count of edges between c and d
    end
end

%// Divide diagonal half because counted twice
mcd(1:ncomms+1:ncomms*ncomms)=mcd(1:ncomms+1:ncomms*ncomms)/2;

例如,在这里的图片中,邻接矩阵是

W=[0 1 1 0 0 0;
   1 0 1 1 0 0;
   1 1 0 0 1 1;
   0 1 0 0 1 0;
   0 0 1 1 0 1;
   0 0 1 0 1 0];

组标签向量为ci=[ 1 1 1 2 2 3],结果矩阵mcd为:

mcd=[3 2 1; 
     2 1 1;
     1 1 0];

这意味着例如组1与自身有3个链接,2个链接与组2,1个链接与组3。

simple community structure

3 个答案:

答案 0 :(得分:3)

这个怎么样?

C = bsxfun(@eq, ci,unique(ci)');
mcd = C*W*C'
mcd(logical(eye(size(mcd)))) = mcd(logical(eye(size(mcd))))./2;

我认为这就是你想要的。

答案 1 :(得分:1)

IIUC并假设unpack-dependencies为排序数组,看起来您基本上是在进行逐块求和,但是块大小不规则。因此,您可以使用沿行和列使用cumsum的方法,然后在ci中的移位位置进行区分,这基本上会给出块顺式求和。

实现看起来像这样 -

ci

答案 2 :(得分:1)

如果您不反对mex功能,可以使用下面的代码。

测试代码

n = 2000;
n_labels = 800;
W = rand(n, n);               

W = W * W' > .5;              % generate symmetric adjacency matrix of logicals
Wd = double(W);
ci = floor(rand(n, 1) * n_labels ) + 1; % generate ids from 1 to 251

[C, IA, IC] = unique(ci);

disp(sprintf('base avg fun time = %g ',timeit(@() interlinks(W, IC))));
disp(sprintf('mex avg fun time = %g ',timeit(@() interlink_mex(W, IC))));

%note this function requires symmetric (function from @aarbelle)
disp(sprintf('bsx avg fun time = %g ',timeit(@() interlinks_bsx(Wd, IC'))));

x1 = interlinks(W, IC);
x2 = interlink_mex(W, IC);
x3 = interlinks_bsx(Wd, IC');

disp(sprintf('norm(x1 - x2) = %g', norm(x1 - x2)));
disp(sprintf('norm(x1 - x3) = %g', norm(x1 - x3)));

测试结果

使用以下设置测试结果:

base avg fun time = 4.94275 
mex avg fun time = 0.0373092 
bsx avg fun time = 0.126406 
norm(x1 - x2) = 0
norm(x1 - x3) = 0

基本上,对于小n_labels,bsx函数表现非常好,但你可以使它足够大,以便mex函数更快。

c ++代码

将其放入像interlink_mex.cpp这样的文件中并使用mex interlink_mex.cpp进行编译。你的机器上需要一个c ++编译器......

#include "mex.h"
#include "matrix.h"
#include <math.h>

//  Author: Matthew Gunn

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
  if(nrhs != 2)
    mexErrMsgTxt("Invalid number of inputs.  Shoudl be 2 input argument.");

  if(nlhs != 1)
    mexErrMsgTxt("Invalid number of outputs.  Should be 1 output arguments.");

  if(!mxIsLogical(prhs[0])) {
    mexErrMsgTxt("First argument should be a logical array (i.e. type logical)");
  }
  if(!mxIsDouble(prhs[1])) {
    mexErrMsgTxt("Second argument should be an array of type double");

  }

  const mxArray *W = prhs[0];
  const mxArray *ci = prhs[1];

  size_t W_m = mxGetM(W);
  size_t W_n = mxGetN(W);

  if(W_m != W_n)
    mexErrMsgTxt("Rows and columns of W are not equal");

  //  size_t ci_m = mxGetM(ci);
  size_t ci_n = mxGetNumberOfElements(ci);


  mxLogical *W_data = mxGetLogicals(W);
  //  double *W_data = mxGetPr(W);
  double *ci_data = mxGetPr(ci);

  size_t *ci_data_size_t = (size_t*) mxCalloc(ci_n, sizeof(size_t));
  size_t ncomms = 0;

  double intpart;
  for(size_t i = 0; i < ci_n; i++) {
    double x = ci_data[i];
    if(x < 1 || x > 65536 || modf(x, &intpart) != 0.0) {
       mexErrMsgTxt("Input ci is not all integers from 1 to a maximum value of 65536 (can edit source code to change this)");

     }
    size_t xx = (size_t) x;
    if(xx > ncomms)
      ncomms = xx;
    ci_data_size_t[i] = xx - 1;
  }

  mxArray *mcd = mxCreateDoubleMatrix(ncomms, ncomms, mxREAL);
  double *mcd_data = mxGetPr(mcd);


  for(size_t i = 0; i < W_n; i++) {
    size_t ii = ci_data_size_t[i];
    for(size_t j = 0; j < W_n; j++) {  
      size_t jj = ci_data_size_t[j];
      mcd_data[ii + jj * ncomms] += (W_data[i + j * W_m] != 0);
    }    
  }
  for(size_t i = 0; i < ncomms * ncomms; i+= ncomms + 1) //go along diagonal
    mcd_data[i]/=2; //divide by 2

  mxFree(ci_data_size_t);
  plhs[0] = mcd;
}