Matlab:使用稀疏结构以二进制数加和的快速方法?

时间:2013-11-07 12:06:00

标签: matlab math binary sparse-matrix

大多数答案仅针对已经回答的关于Hamming weights的问题,但忽略了有关find和处理稀疏性的观点。显然,Shai here的答案解决了关于查找的问题 - 但我还无法对其进行验证。我的回答here没有利用其他答案的聪明才智,例如比特转换但是足够好的例子答案。

输入

>> mlf=sparse([],[],[],2^31+1,1);mlf(1)=10;mlf(10)=111;mlf(77)=1010;  
>> transpose(dec2bin(find(mlf)))

ans =

001
000
000
011
001
010
101

目标

1
0
0
2
1
1
2    

使用稀疏结构快速计算二进制数的数量?

7 个答案:

答案 0 :(得分:6)

你可以通过多种方式实现这一目标。我认为最简单的是

% Example data
F = [268469248 285213696 536904704 553649152];

% Solution 1
sum(dec2bin(F)-'0',2)

最快(找到here):

% Solution 2
w = uint32(F');

p1 = uint32(1431655765);
p2 = uint32(858993459);
p3 = uint32(252645135);
p4 = uint32(16711935);
p5 = uint32(65535);

w = bitand(bitshift(w, -1), p1) + bitand(w, p1);
w = bitand(bitshift(w, -2), p2) + bitand(w, p2);
w = bitand(bitshift(w, -4), p3) + bitand(w, p3);
w = bitand(bitshift(w, -8), p4) + bitand(w, p4);
w = bitand(bitshift(w,-16), p5) + bitand(w, p5);

答案 1 :(得分:5)

根据您的评论,您可以使用dec2bin将数字向量转换为二进制字符串表示形式。然后你可以按照以下方式实现你想要的,我用vector [10 11 12]作为例子:

>> sum(dec2bin([10 11 12])=='1',2)

ans =

     2
     3
     2

或等效地,

>> sum(dec2bin([10 11 12])-'0',2)

对于速度,您可以避免像这样dec2bin(使用modulo-2操作,灵感来自dec2bin代码):

>> sum(rem(floor(bsxfun(@times, [10 11 12].', pow2(1-N:0))),2),2)

ans =

     2
     3
     2

其中N是您期望的最大二进制数字数。

答案 2 :(得分:4)

如果你真的想要快,我认为查找表会很方便。您可以简单地映射,0..255它们有多少个。这样做一次,然后你只需要将一个int分解为它的字节,看看表中的总和并添加结果 - 不需要去字符串......


一个例子:

>> LUT = sum(dec2bin(0:255)-'0',2); % construct the look up table (only once)
>> ii = uint32( find( mlf ) ); % get the numbers
>> vals = LUT( mod( ii, 256 ) + 1 ) + ... % lower bytes
          LUT( mod( ii/256, 256 ) + 1 ) + ...
          LUT( mod( ii/65536, 256 ) + 1 ) + ...
          LUT( mod( ii/16777216, 256 ) + 1 );

使用typecast(由Amro建议):

>> vals = sum( reshape(LUT(double(typecast(ii,'uint8'))+1), 4, [] ), 1 )';

运行时间比较

>> ii = uint32(randi(intmax('uint32'),100000,1));
>> tic; vals1 = sum( reshape(LUT(typecast(ii,'uint8')+1), 4, [] ), 1 )'; toc, %//'
>> tic; vals2 = sum(dec2bin(ii)-'0',2); toc
>> dii = double(ii); % type issues
>> tic; vals3 = sum(rem(floor(bsxfun(@times, dii, pow2(1-32:0))),2),2); toc

结果:

Elapsed time is 0.006144 seconds.  <-- this answer
Elapsed time is 0.120216 seconds.  <-- using dec2bin
Elapsed time is 0.118009 seconds.  <-- using rem and bsxfun

答案 3 :(得分:3)

以下是显示@Shai使用查找表的想法的示例:

% build lookup table for 8-bit integers
lut = sum(dec2bin(0:255)-'0', 2);

% get indices
idx = find(mlf);

% break indices into 8-bit integers and apply LUT
nbits = lut(double(typecast(uint32(idx),'uint8')) + 1);

% sum number of bits in each
s = sum(reshape(nbits,4,[]))

如果你的大型索引在32位范围之外,那么你可能不得不切换到uint64


编辑:

以下是使用Java的另一种解决方案:

idx = find(mlf);
s = arrayfun(@java.lang.Integer.bitCount, idx);

编辑#2:

这是另一种作为C ++ MEX功能实现的解决方案。它依赖于std::bitset::count

bitset_count.cpp

#include "mex.h"
#include <bitset>

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    // validate input/output arguments
    if (nrhs != 1) {
        mexErrMsgTxt("One input argument required.");
    }
    if (!mxIsUint32(prhs[0]) || mxIsComplex(prhs[0]) || mxIsSparse(prhs[0])) {
        mexErrMsgTxt("Input must be a 32-bit integer dense matrix.");
    }
    if (nlhs > 1) {
        mexErrMsgTxt("Too many output arguments.");
    }

    // create output array
    mwSize N = mxGetNumberOfElements(prhs[0]);
    plhs[0] = mxCreateDoubleMatrix(N, 1, mxREAL);

    // get pointers to data
    double *counts = mxGetPr(plhs[0]);
    uint32_T *idx = reinterpret_cast<uint32_T*>(mxGetData(prhs[0]));

    // count bits set for each 32-bit integer number
    for(mwSize i=0; i<N; i++) {
        std::bitset<32> bs(idx[i]);
        counts[i] = bs.count();
    }
}

将上述函数编译为mex -largeArrayDims bitset_count.cpp,然后照常运行:

idx = find(mlf);
s = bitset_count(uint32(idx))

我决定比较到目前为止提到的所有解决方案:

function [t,v] = testBitsetCount()
    % random data (uint32 vector)
    x = randi(intmax('uint32'), [1e5,1], 'uint32');

    % build lookup table (done once)
    LUT = sum(dec2bin(0:255,8)-'0', 2);

    % functions to compare
    f = {
        @() bit_twiddling(x)      % bit twiddling method
        @() lookup_table(x,LUT);  % lookup table method
        @() bitset_count(x);      % MEX-function (std::bitset::count)
        @() dec_to_bin(x);        % dec2bin
        @() java_bitcount(x);     % Java Integer.bitCount
    };

    % compare timings and check results are valid
    t = cellfun(@timeit, f, 'UniformOutput',true);
    v = cellfun(@feval, f, 'UniformOutput',false);
    assert(isequal(v{:}));
end

function s = lookup_table(x,LUT)
    s = sum(reshape(LUT(double(typecast(x,'uint8'))+1),4,[]))';
end

function s = dec_to_bin(x)
    s = sum(dec2bin(x,32)-'0', 2);
end

function s = java_bitcount(x)
    s = arrayfun(@java.lang.Integer.bitCount, x);
end

function s = bit_twiddling(x)
    p1 = uint32(1431655765);
    p2 = uint32(858993459);
    p3 = uint32(252645135);
    p4 = uint32(16711935);
    p5 = uint32(65535);

    s = x;
    s = bitand(bitshift(s, -1), p1) + bitand(s, p1);
    s = bitand(bitshift(s, -2), p2) + bitand(s, p2);
    s = bitand(bitshift(s, -4), p3) + bitand(s, p3);
    s = bitand(bitshift(s, -8), p4) + bitand(s, p4);
    s = bitand(bitshift(s,-16), p5) + bitand(s, p5);
end

以秒为单位的时间:

t = 
    0.0009    % bit twiddling method
    0.0087    % lookup table method
    0.0134    % C++ std::bitset::count
    0.1946    % MATLAB dec2bin
    0.2343    % Java Integer.bitCount

答案 4 :(得分:0)

这为您提供了稀疏结构中二进制数的rowums。

>> mlf=sparse([],[],[],2^31+1,1);mlf(1)=10;mlf(10)=111;mlf(77)=1010;  
>> transpose(dec2bin(find(mlf)))

ans =

001
000
000
011
001
010
101

>> sum(ismember(transpose(dec2bin(find(mlf))),'1'),2)

ans =

     1
     0
     0
     2
     1
     1
     2

希望有人能够找到更快的rowummation!

答案 5 :(得分:0)

Mex it!


将此代码保存为countTransBits.cpp

#include "mex.h"

void mexFunction( int nout, mxArray* pout[], int nin, mxArray* pin[] ) {
mxAssert( nin == 1 && mxIsSparse(pin[0]) && mxGetN( pin[0] ) == 1,
            "expecting single sparse column vector input" );
    mxAssert( nout == 1, "expecting single output" );

    // set output, assuming 32 bits, set to 64 if needed
    pout[0] = mxCreateNumericMatrix( 32, 1, mxUINT32_CLASS, mxREAL );
    unsigned int* counter = (unsigned int*)mxGetData( pout[0] );
    for ( int i = 0; i < 32; i++ ) {
        counter[i] = 0;
    }

    // start working
    mwIndex *pIr = mxGetIr( pin[0] );
    mwIndex* pJc = mxGetJc( pin[0] );
    double*  pr = mxGetPr( pin[0] );
    for ( mwSize i = pJc[0]; i < pJc[1]; i++ ) {
        if ( pr[i] != 0 ) {// make sure entry is non-zero
            unsigned int entry = pIr[i] + 1; // cast to unsigned int and add 1 for 1-based indexing in Matlab
            int bit = 0;
            while ( entry != 0 && bit < 32 ) {
                counter[bit] += ( entry & 0x1 ); // count the lsb
                bit++;
                entry >>= 1; // shift right
            }
        }
    }
}

在Matlab中编译

>> mex -largeArrayDims -O countTransBits.cpp

运行代码

>> countTransBits( mlf )

请注意,输出计数为32 bins lsb to msb。

答案 6 :(得分:0)

bitcount FEX贡献提供了基于查找表方法的解决方案,但更好地进行了优化。在我的旧笔记本电脑上使用R2015a,它运行速度超过了比特twiddling方法(即Amro报道的最快的纯MATLAB方法)超过100万uint32矢量的两倍。