我有一个图,用它的 n x n 邻接矩阵 W 以及一个长度为 n 的向量来描述,该向量给出每个节点的组标签(整数)。
对于每一对组,我需要计算组 c 中的节点与组 d 中的节点之间的链接(边)数。为此我写了一个嵌套的 for 循环,但我确信这不是计算矩阵 mcd(我在代码中对它的称呼)的最快方法;mcd 是记录组 c 与组 d 之间边数的矩阵。
是否可以通过bsxfun
更快地完成此操作?
function mcd = interlinks(W,ci)
%INTERLINKS Count the edges inside and between node groups of a graph.
%   mcd = INTERLINKS(W,ci) takes the adjacency matrix W of a simple
%   undirected graph and a vector ci holding one group label per node.
%   The result has one row/column per distinct label in ci (in sorted
%   order, so labels 1..|C| map to rows 1..|C|). mcd(c,d) is the sum of the
%   entries of W between the nodes of group c and the nodes of group d;
%   diagonal entries are halved because every within-group edge of a
%   symmetric W is seen twice.

labels = unique(ci);        % distinct group labels, sorted
ncomms = numel(labels);     % number of groups

% Look up the members of every group once instead of once per (c,d) pair.
members = cell(ncomms,1);
for c = 1:ncomms
    members{c} = find(ci == labels(c));
end

mcd = zeros(ncomms);
for c = 1:ncomms
    for d = 1:ncomms
        block = W(members{c}, members{d});  % edges between group c and group d
        mcd(c,d) = sum(block(:));
    end
end

% Halve the diagonal: within-group edges were counted twice.
diagIdx = 1:ncomms+1:ncomms*ncomms;
mcd(diagIdx) = mcd(diagIdx)/2;
例如,在这里的图片中,邻接矩阵是
W=[0 1 1 0 0 0;
1 0 1 1 0 0;
1 1 0 0 1 1;
0 1 0 0 1 0;
0 0 1 1 0 1;
0 0 1 0 1 0];
组标签向量为ci=[ 1 1 1 2 2 3]
,结果矩阵mcd
为:
mcd=[3 2 1;
2 1 1;
1 1 0];
这意味着,例如,组1内部有3条边,组1与组2之间有2条边,组1与组3之间有1条边。
答案 0 :(得分:3)
这个怎么样?
% Group-membership indicator: C(k,i) is 1 when node i carries the k-th
% distinct label. ci(:).' forces a row vector, so ci may be a row or a column.
labels = unique(ci(:));
C = double(bsxfun(@eq, ci(:).', labels));
% C*W*C' sums the entries of W over every pair of groups at once.
mcd = C*double(W)*C';
% Within-group edges of a symmetric W are counted twice: halve the diagonal.
diagIdx = 1:size(mcd,1)+1:numel(mcd);
mcd(diagIdx) = mcd(diagIdx)/2;
我认为这就是你想要的。
答案 1 :(得分:1)
如果我理解正确(IIUC),并假设 ci 是已排序的数组,看起来您基本上是在做逐块求和,只是块的大小不规则。因此,您可以先沿行和列分别使用 cumsum,然后在 ci 中标签发生变化的位置处做差分,这样基本上就得到了逐块求和的结果。
实现看起来像这样 -
ci
答案 2 :(得分:1)
如果您不反对使用 mex 函数,可以使用下面的代码。
% Benchmark: compare the loop-based, MEX and bsxfun-based implementations
% on a random graph and check that they return the same matrix.
n = 2000;        % number of nodes
n_labels = 800;  % upper bound on the number of distinct group ids
W = rand(n, n);
W = W * W' > .5; % generate symmetric adjacency matrix of logicals
Wd = double(W);  % double copy passed to the bsxfun-based variant
ci = floor(rand(n, 1) * n_labels ) + 1; % generate ids from 1 to n_labels
[~, ~, IC] = unique(ci); % relabel to consecutive ids 1..numel(unique(ci))
fprintf('base avg fun time = %g \n', timeit(@() interlinks(W, IC)));
fprintf('mex avg fun time = %g \n', timeit(@() interlink_mex(W, IC)));
% note (from the original author): this function requires symmetric W
% (function from @aarbelle); it is called here with Wd and a row vector of labels
fprintf('bsx avg fun time = %g \n', timeit(@() interlinks_bsx(Wd, IC')));
% Run each implementation once more and compare the results.
x1 = interlinks(W, IC);
x2 = interlink_mex(W, IC);
x3 = interlinks_bsx(Wd, IC');
fprintf('norm(x1 - x2) = %g\n', norm(x1 - x2));
fprintf('norm(x1 - x3) = %g\n', norm(x1 - x3));
使用以下设置测试结果:
base avg fun time = 4.94275
mex avg fun time = 0.0373092
bsx avg fun time = 0.126406
norm(x1 - x2) = 0
norm(x1 - x3) = 0
基本上,对于小n_labels
,bsx函数表现非常好,但你可以使它足够大,以便mex函数更快。
将其放入类似 interlink_mex.cpp 的文件中,并使用 mex interlink_mex.cpp 进行编译。您的机器上需要有 C++ 编译器……
#include "mex.h"
#include "matrix.h"
#include <math.h>
// Author: Matthew Gunn
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
if(nrhs != 2)
mexErrMsgTxt("Invalid number of inputs. Shoudl be 2 input argument.");
if(nlhs != 1)
mexErrMsgTxt("Invalid number of outputs. Should be 1 output arguments.");
if(!mxIsLogical(prhs[0])) {
mexErrMsgTxt("First argument should be a logical array (i.e. type logical)");
}
if(!mxIsDouble(prhs[1])) {
mexErrMsgTxt("Second argument should be an array of type double");
}
const mxArray *W = prhs[0];
const mxArray *ci = prhs[1];
size_t W_m = mxGetM(W);
size_t W_n = mxGetN(W);
if(W_m != W_n)
mexErrMsgTxt("Rows and columns of W are not equal");
// size_t ci_m = mxGetM(ci);
size_t ci_n = mxGetNumberOfElements(ci);
mxLogical *W_data = mxGetLogicals(W);
// double *W_data = mxGetPr(W);
double *ci_data = mxGetPr(ci);
size_t *ci_data_size_t = (size_t*) mxCalloc(ci_n, sizeof(size_t));
size_t ncomms = 0;
double intpart;
for(size_t i = 0; i < ci_n; i++) {
double x = ci_data[i];
if(x < 1 || x > 65536 || modf(x, &intpart) != 0.0) {
mexErrMsgTxt("Input ci is not all integers from 1 to a maximum value of 65536 (can edit source code to change this)");
}
size_t xx = (size_t) x;
if(xx > ncomms)
ncomms = xx;
ci_data_size_t[i] = xx - 1;
}
mxArray *mcd = mxCreateDoubleMatrix(ncomms, ncomms, mxREAL);
double *mcd_data = mxGetPr(mcd);
for(size_t i = 0; i < W_n; i++) {
size_t ii = ci_data_size_t[i];
for(size_t j = 0; j < W_n; j++) {
size_t jj = ci_data_size_t[j];
mcd_data[ii + jj * ncomms] += (W_data[i + j * W_m] != 0);
}
}
for(size_t i = 0; i < ncomms * ncomms; i+= ncomms + 1) //go along diagonal
mcd_data[i]/=2; //divide by 2
mxFree(ci_data_size_t);
plhs[0] = mcd;
}