通过多个向量快速乘法多个矩阵

时间:2017-09-15 20:47:10

标签: matlab performance matrix matrix-multiplication

在matlab中,我想使用L矩阵乘以M个向量,得到M×L个新向量。具体来说,假设我有一个大小为N x M的矩阵A和一个大小为N x N x L矩阵的矩阵B,我想计算一个大小为N x M x L的矩阵C,其中结果与下面的结果完全相同代码:

for m=1:M
    for l=1:L
         C(:,m,l)=B(:,:,l)*A(:,m)
    end
end

但是要有效地实现这一点(使用本机代码而不是matlab循环)。

2 个答案:

答案 0 :(得分:5)

我们可以在这里 ab-use fast matrix-multiplication,只需要重新排列尺寸。因此,将B的第二个维度推回到最后并重新整形为2D,以便合并前两个dims。使用A执行矩阵乘法,为我们提供2D数组。我们称之为C。现在,C's首先昏暗的是来自B的合并后的昏暗。因此,将其拆分为原始的两个暗淡长度,并重新整形,从而产生3D阵列。最后再用另外一个permute将第二个暗点推回到后面。这是期望的3D输出。

因此,实施将是 -

permute(reshape(reshape(permute(B,[1,3,2]),[],N)*A,N,L,[]),[1,3,2])

基准

基准代码:

% Setup inputs
M = 150;
L = 150;
N = 150;
A = randn(N,M);
B = randn(N,N,L);

disp('----------------------- ORIGINAL LOOPY -------------------')
tic
C_loop = NaN(N,M,L);
for m=1:M
    for l=1:L
         C_loop(:,m,l)=B(:,:,l)*A(:,m);
    end
end
toc

disp('----------------------- BSXFUN + PERMUTE -----------------')
% @Luis's soln
tic
C = permute(sum(bsxfun(@times, permute(B, [1 2 4 3]), ...
                        permute(A, [3 1 2])), 2), [1 3 4 2]);
toc

disp('----------------------- BSXFUN + MATRIX-MULT -------------')
% Propose in this post
tic
out = permute(reshape(reshape(permute(B,[1,3,2]),[],N)*A,N,L,[]),[1,3,2]);
toc

时间:

----------------------- ORIGINAL LOOPY -------------------
Elapsed time is 0.905811 seconds.
----------------------- BSXFUN + PERMUTE -----------------
Elapsed time is 0.883616 seconds.
----------------------- BSXFUN + MATRIX-MULT -------------
Elapsed time is 0.045331 seconds.

答案 1 :(得分:4)

您可以通过维度和单例扩展的一些置换来实现:

C = permute(sum(bsxfun(@times, permute(B, [1 2 4 3]), permute(A, [3 1 2])), 2), [1 3 4 2]);

检查:

% Example inputs:
M = 5;
L = 6;
N = 7;
A = randn(N,M);
B = randn(N,N,L);

% Output with bsxfun and permute:    
C = permute(sum(bsxfun(@times, permute(B, [1 2 4 3]), permute(A, [3 1 2])), 2), [1 3 4 2]);

% Output with loops:
C_loop = NaN(N,M,L);
for m=1:M
    for l=1:L
         C_loop(:,m,l)=B(:,:,l)*A(:,m);
    end
end

% Maximum relative error. Should be 0, or of the order of eps:
max_error = max(reshape(abs(C./C_loop),[],1)-1)