Matrix multiplication of corresponding 2D slices of two arrays, and inversion of array slices

Asked: 2014-10-28 11:20:49

Tags: arrays matlab vectorization matrix-multiplication

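I have two arrays A and B of the same dimensions, 1000 x 3 x 20 x 20. I want to generate a third array C of dimensions 3 x 3 x 20 x 20 that holds the matrix product of the corresponding slices of A and B, i.e. C(:,:,i,j) = A(:,:,i,j)'*B(:,:,i,j). Then I need to turn C into a new array D by inverting each of its 3 x 3 matrices, i.e. D(:,:,i,j) = inv(C(:,:,i,j)). Again, it is clear how to do this with loops. Is there a way to avoid looping over the 400 items?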
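For reference, the loop version described above can be sketched as follows. This is a minimal baseline under assumed, deliberately small sizes (n1 and n2 are illustrative, not the 1000 and 20 from the question), useful mainly for checking any vectorized replacement against.

```matlab
% Loop baseline for both steps on small random inputs (sizes are illustrative).
n1 = 10; n2 = 4;
A = rand(n1,3,n2,n2);
B = rand(n1,3,n2,n2);

C = zeros(3,3,n2,n2);
D = zeros(3,3,n2,n2);
for ii = 1:n2
    for jj = 1:n2
        C(:,:,ii,jj) = A(:,:,ii,jj)' * B(:,:,ii,jj);   % 3 x 3 product of one slice pair
        D(:,:,ii,jj) = inv(C(:,:,ii,jj));              % inverse of that 3 x 3 slice
    end
end
```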

Edit: Benchmark code for comparing the performance of the different solutions:

%// Inputs
n1 = 50;
n2 = 200;
A = rand(n1,3,n2,n2);
B = rand(n1,3,n2,n2);

%// A. CPU loopy code
tic
C = zeros(3,3,n2,n2);
for ii = 1:n2
    for jj = 1:n2
        C(:,:,ii,jj) = A(:,:,ii,jj)'*B(:,:,ii,jj); %//'
    end
end
toc

%// B. Vectorized code (using squeeze)
tic
C1 = squeeze(sum(bsxfun(@times,permute(A,[2 1 5 3 4]),permute(B,[5 1 2 3 4])),2));
toc

%// C. Vectorized code (avoiding squeeze)
tic
C2 = sum(bsxfun(@times,permute(A,[2 5 3 4 1]),permute(B,[5 2 3 4 1])),5);
toc

%// D. GPU vectorized code
tic
A = gpuArray(A);
B = gpuArray(B);
C3 = sum(bsxfun(@times,permute(A,[2 5 3 4 1]),permute(B,[5 2 3 4 1])),5);
C3 = gather(C3);
toc

Runtime results, in the order A to D:

Elapsed time is 0.287511 seconds.
Elapsed time is 0.250663 seconds.
Elapsed time is 0.337628 seconds.
Elapsed time is 1.259207 seconds.

1 answer:

Answer 0 (score: 1)

Code

%// Part - 1
C = sum(bsxfun(@times,permute(A,[2 5 3 4 1]),permute(B,[5 2 3 4 1])),5);

%// Part - 2: Use MATLAB file-exchange tool multinv
D = multinv(C);

The code for multinv is available on the MATLAB File Exchange and is claimed to be very efficient.
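If a File Exchange dependency is not wanted, newer MATLAB releases ship page-wise built-ins that cover both parts. The sketch below is not the method used in this answer; it assumes a release that has pagemtimes (R2020b or later) and pageinv (R2022a or later), and it uses small illustrative sizes.

```matlab
% Sketch: page-wise product and inverse with built-ins from newer releases.
% Assumes pagemtimes (R2020b+) and pageinv (R2022a+) are available.
n1 = 10; n2 = 4;                 % small illustrative sizes
A = rand(n1,3,n2,n2);
B = rand(n1,3,n2,n2);

% Part 1: C(:,:,i,j) = A(:,:,i,j)' * B(:,:,i,j) for every (i,j)
C = pagemtimes(A,'ctranspose',B,'none');

% Part 2: D(:,:,i,j) = inv(C(:,:,i,j)) for every (i,j)
D = pageinv(C);

% Compare Part 1 against the bsxfun expression from the answer
Cref = sum(bsxfun(@times,permute(A,[2 5 3 4 1]),permute(B,[5 2 3 4 1])),5);
maxErrC = max(abs(C(:) - Cref(:)));
```

The two Part 1 results should agree up to floating-point rounding, since the summation order inside pagemtimes may differ from that of sum.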

For the first part, you could also try this:

C = squeeze(sum(bsxfun(@times,permute(A,[2 1 5 3 4]),permute(B,[5 1 2 3 4])),2));

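This one appears to rearrange the elements less "destructively" than the version in the code above, but the drawback is that it needs squeeze, which could slow it down. I will leave it to you, and encourage you to benchmark the two and choose the better one.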
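Before benchmarking the two variants, it is worth confirming that they produce the same array. A minimal check, assuming small random inputs (the sizes and the names C_a, C_b, maxDiff are illustrative):

```matlab
% Check that the two bsxfun variants agree (small illustrative sizes).
n1 = 10; n2 = 4;
A = rand(n1,3,n2,n2);
B = rand(n1,3,n2,n2);

% Variant without squeeze: summed dimension is moved to the end (dim 5)
C_a = sum(bsxfun(@times,permute(A,[2 5 3 4 1]),permute(B,[5 2 3 4 1])),5);

% Variant with squeeze: summed dimension stays in position 2
C_b = squeeze(sum(bsxfun(@times,permute(A,[2 1 5 3 4]),permute(B,[5 1 2 3 4])),2));

sameSize = isequal(size(C_a), size(C_b));
maxDiff  = max(abs(C_a(:) - C_b(:)));   % expected to be at rounding level
```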


Why bsxfun + GPU

I increased the loop limits, since that is probably the real test between loopy and vectorized code. So, here is the modified code for part 1:

%// Inputs
n1 = 50;
n2 = 200;
A = rand(n1,3,n2,n2);
B = rand(n1,3,n2,n2);

%// A. CPU loopy code
tic
C = zeros(3,3,n2,n2);
for ii = 1:n2
    for jj = 1:n2
        C(:,:,ii,jj) = A(:,:,ii,jj)'*B(:,:,ii,jj); %//'
    end
end
toc

%// B. GPU vectorized code
tic
A = gpuArray(A);
B = gpuArray(B);
C1 = sum(bsxfun(@times,permute(A,[2 5 3 4 1]),permute(B,[5 2 3 4 1])),5);
C1 = gather(C1);
toc

The runtime results on my system are:

Elapsed time is 0.310056 seconds.
Elapsed time is 0.172499 seconds.

So, you see!
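Note that the tic/toc measurements above are single runs, and the GPU timing includes the gpuArray transfers and the gather. A possible way to get steadier numbers, sketched here under the assumption that timeit is available and, for the GPU branch, that Parallel Computing Toolbox and a supported GPU are present (the helper f and the variables tCPU, tGPU are illustrative names):

```matlab
% Sketch: repeatable timings with timeit / gputimeit instead of one tic/toc run.
n1 = 50; n2 = 200;
A = rand(n1,3,n2,n2);
B = rand(n1,3,n2,n2);

% Same vectorized expression as in the answer, wrapped in a function handle
f = @(X,Y) sum(bsxfun(@times,permute(X,[2 5 3 4 1]),permute(Y,[5 2 3 4 1])),5);

tCPU = timeit(@() f(A,B));            % repeated runs on the CPU

try
    Ag = gpuArray(A);                 % transfer kept outside the timed region
    Bg = gpuArray(B);
    tGPU = gputimeit(@() f(Ag,Bg));   % GPU compute only, no transfer or gather
catch
    tGPU = NaN;                       % no toolbox or no supported GPU
end
```

Whether the transfer time should count depends on where the data lives in the real application, so timing both with and without it may be informative.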