改进MATLAB矩阵构造代码:或者,为初学者编写代码矢量化

时间:2013-07-09 05:34:26

标签: matlab matrix vectorization wavelet

我写了一个程序,以构建一个3波段小波变换矩阵的一部分。但是,考虑到矩阵的大小为3 ^ 9 X 3 ^ 10,MATLAB完成构建需要一段时间。因此,我想知道是否有办法改进我用来使其运行得更快的代码。我在运行代码时使用n = 10。

B=zeros(3^(n-1),3^n);
v=[-0.117377016134830 0.54433105395181 -0.0187057473531300 -0.699119564792890 -0.136082763487960 0.426954037816980 ];

for j=1:3^(n-1)-1 
    for k=1:3^n;
        if k>6+3*(j-1) || k<=3*(j-1)
            B(j,k)=0;
        else 
            B(j,k)=v(k-3*(j-1));
        end                
    end
end
j=3^(n-1);
    for k=1:3^n
        if k<=3
            B(j,k)=v(k+3);
        elseif k<=3^n-3
            B(j,k)=0;
        else 
            B(j,k)=v(k-3*(j-1));
        end
    end

W=B;

3 个答案:

答案 0 :(得分:11)

如何在不知道如何进行矢量化的情况下进行矢量化:

首先,我只讨论向量化第一个双循环,你可以按照相同的逻辑进行第二个循环。

我试图从头开始展示一个思考过程,所以虽然最终的答案只有2行,但是值得看看初学者如何尝试获得它。

首先,我建议&#34;按摩&#34;在简单的情况下,代码,以了解它。例如,我使用了n=3v=1:6并运行了第一个循环,这就是B的样子:

[N M]=size(B)
N =
     9
M =
    27

imagesc(B); 

enter image description here

所以你可以看到我们得到一个像矩阵一样的楼梯,这很规律!我们所需要的只是将正确的矩阵索引分配给v及其正确的值。

有很多方法可以实现这一目标,有些方法比其他方法更优雅。最简单的方法之一是使用函数find

pos=[find(B==v(1)),find(B==v(2)),find(B==v(3)),...
     find(B==v(4)),find(B==v(5)),find(B==v(6))]

pos =
     1    10    19    28    37    46
    29    38    47    56    65    74
    57    66    75    84    93   102
    85    94   103   112   121   130
   113   122   131   140   149   158
   141   150   159   168   177   186
   169   178   187   196   205   214
   197   206   215   224   233   242

以上值是矩阵B linear indices ,其中找到v的值。每列代表v中特定值B的{​​{3}}。例如,索引[1 29 57 ...]都包含值v(1)等...每行完全包含v,因此索引[29 38 47 56 65 74]包含v=[v(1) v(2) ... v(6)] 。您可以注意到,对于每一行,索引之间的差异为9,或者,每个索引用步长N分隔,并且其中有6个,这只是向量v的长度(也是由numel(v)获得的。对于列,相邻元素之间的差异为28,或者步长为M+1

我们只需要根据这个逻辑在适当的索引中分配v的值。一种方法是写每一行&#34;行&#34;:

B([1:N:numel(v)*N]+(M+1)*0)=v;
B([1:N:numel(v)*N]+(M+1)*1)=v;
...
B([1:N:numel(v)*N]+(M+1)*(N-2))=v;

但是对于大N-2来说这是不切实际的,所以如果你真的想要你可以在for循环中做到这一点:

for kk=0:N-2;
     B([1:N:numel(v)*N]+(M+1)*kk)=v;
end

Matlab提供了一种使用bsxfun一次获取所有索引的更有效方法(这取代了for循环),例如:

ind=bsxfun(@plus,1:N:N*numel(v),[0:(M+1):M*(N-1)+1]')

现在我们可以使用indv分配到矩阵N-1次。为此我们需要&#34;展平&#34; ind成行向量:

ind=reshape(ind.',1,[]);

并将v连接到自身N-1次(或使{1}}的N-1个更多副本:

v

最后我们得到答案:

vec=repmat(v,[1 N-1]);

长话短说,并且紧凑地编写,我们得到一个2行解决方案(给定大小B(ind)=vec; 已知:B):


[N M]=size(B)

对于ind=bsxfun(@plus,1:N:N*numel(v),[0:(M+1):M*(N-1)+1]'); B(reshape(ind.',1,[]))=repmat(v,[1 N-1]); ,矢量化代码在我的机器中快〜850。 (小n=9将不太重要)

由于获得的矩阵主要由零组成,因此您不需要将这些矩阵分配给完整矩阵,而是使用稀疏矩阵,此处& #39; s的完整代码(非常相似):

n

对于N=3^(n-1); M=3^n; S=sparse([],[],[],N,M); ind=bsxfun(@plus,1:N:N*numel(v),[0:(M+1):M*(N-1)+1]'); S(reshape(ind.',1,[]))=repmat(v,[1 N-1]); ,我只能运行稀疏矩阵代码(否则会内存不足),而在我的机器中则需要约6秒钟。

现在尝试将其应用于第二个循环...

答案 1 :(得分:8)

虽然你的矩阵具有巨大的尺寸,但它也非常“稀疏”,这意味着它的大多数元素都是零。为了提高性能,您可以使用MATLAB的稀疏矩阵支持,确保您只对矩阵的非零部分进行操作。

通过构造稀疏矩阵的coordinate form,可以有效地构建MATLAB中的稀疏矩阵。这意味着必须定义三个数组,用于定义矩阵中每个非零项的行,列和值。这意味着我们不是通过传统的A(i,j) = x语法分配值,而是将非零条目附加到我们的稀疏索引结构上:

row(pos+1) = i;
col(pos+1) = j;
val(pos+1) = x;
% pos is the current position within the sparse indexing arrays!

一旦我们在稀疏索引数组中拥有完整的非零值,我们就可以使用sparse命令来构建矩阵。

对于这个问题,我们为每一行添加最多六个非零条目,允许我们提前分配稀疏索引数组。变量pos跟踪索引数组中的当前位置。

rows = 3^(n-1);
cols = 3^(n+0);

% setup the sparse indexing arrays for non-
% zero elements of matrix B
row = zeros(rows*6,1);
col = zeros(rows*6,1);
val = zeros(rows*6,1);
pos = +0;

我们现在可以通过向稀疏索引数组添加任何非零条目来构建矩阵。由于我们只关心非零条目,我们也只循环遍历矩阵的非零部分。

我已经离开了最后一行的逻辑供您填写!

for j = 1 : 3^(n-1)
    if (j < 3^(n-1))

% add entries for a general row
    for k = max(1,3*(j-1)+1) : min(3^n,3*(j-1)+6)             
        pos = pos+1;
        row(pos) = j;
        col(pos) = k;
        val(pos) = v(k-3*(j-1));                
    end

    else

% add entries for final row - todo!!

    end
end

由于我们没有为每一行添加六个非零,我们可能会过度分配稀疏索引数组,因此我们将它们减少到实际使用的大小。

% only keep the sparse indexing that we've used
row = row(1:pos);
col = col(1:pos);
val = val(1:pos);

现在可以使用sparse命令构建最终矩阵。

% build the actual sparse matrix
B = sparse(row,col,val,rows,cols);

可以通过整理上面的代码段来运行代码。对于n = 9,我们得到以下结果(为了进行比较,我还包括bsxfun建议的natan方法的结果:

Elapsed time is 2.770617 seconds. (original)
Elapsed time is 0.006305 seconds. (sparse indexing)
Elapsed time is 0.261078 seconds. (bsxfun)

原始代码耗尽了n = 10的内存,尽管两种稀疏方法仍然可用:

Elapsed time is 0.019846 seconds. (sparse indexing)
Elapsed time is 2.133946 seconds. (bsxfun)

答案 2 :(得分:2)

您可以使用狡猾的方式创建块对角矩阵,如下所示:

&GT;&GT; v = [ - 0.117377016134830 0.54433105395181 -0.0187057473531300 ...
-0.699119564792890 -0.136082763487960 0.426954037816980];
&GT;&GT; lendiff =长度(V)-3;
&GT;&GT; B = repmat([v zeros(1,3 ^ n-lendiff)],3 ^(n-1),1);
&GT;&GT; B =重塑(B”,3 ^(n)的,3 ^(N-1)+1);
&GT;&GT; B(:,端-1)= B(:,端-1)+ B(:,端);
&GT;&GT; B = B(:,1:端-1)';

此处,lendiff用于创建3 ^ {n-1}个副本,其中v后跟零,其长度为3 ^ n + 3,因此大小为矩阵[ 3 ^ {n-1} 3 ^ n + 3]。

将该矩阵重新整形为大小[3 ^ n 3 ^ {n-1} +1]以创建移位。额外的列需要添加到最后一列,B需要转置。

虽然应该快得多。

修改

看到Darren的解决方案并且意识到reshape也在稀疏矩阵上工作,让我想出这个 - 没有for循环(未编码原始解决方案< / em>的)。

首先是以:

开头的值
>> v=[-0.117377016134830  ...
       0.54433105395181   ...
      -0.0187057473531300 ...
      -0.699119564792890  ...
      -0.136082763487960  ...
       0.426954037816980];    
>> rows = 3^(n-1);                  % same number of rows
>> cols = 3^(n)+3;                  % add 3 cols to implement the shifts    

然后制作每行3个额外列的矩阵

>> row=(1:rows)'*ones(1,length(v)); % row number where each copy of v is stored'
>> col=ones(rows,1)*(1:length(v));  % place v at the start columns of each row
>> val=ones(rows,1)*v;              % fill in the values of v at those positions
>> B=sparse(row,col,val,rows,cols); % make the matrix B[rows cols+3], but now sparse

然后重塑以实现轮班(额外的行,正确的列数)

>> B=reshape(B',3^(n),rows+1);      % reshape into B[3^n rows+1], shifted v per row'
>> B(1:3,end-1)=B(1:3,end);         % the extra column contains last 3 values of v
>> B=B(:,1:end-1)';                 % delete extra column after copying, transpose

对于n = 4,5,6,7,这会导致 s 中的cpu时间:

n    original    new version
4    0.033       0.000
5    0.206       0.000
6    1.906       0.000
7    16.311      0.000

由剖析器测量。对于原始版本,我无法运行n&gt; 7但新版本提供

n    new version
8    0.002
9    0.009
10   0.022
11   0.062
12   0.187
13   0.540
14   1.529
15   4.210

那就是我的RAM走了多远:)。