我有一个特殊的算法,其中作为持续步骤之一,我需要执行3-D数组与2-D数组的乘法,使得3-D数组的每个矩阵切片相乘二维数组的列。换句话说,如果A
是N x N x N
矩阵且B
是N x N
矩阵,我需要计算大小为{{C
的矩阵1}} N x N
。
实现这个的简单方法是循环,即
C(:,i) = A(:,:,i)*B(:,i);
但是,循环不是Matlab中最快的,应该避免。我正在寻找更快的方法来做到这一点。现在,我所做的就是使用这个事实(现在Mathjax会很棒!):
[A1 b1,A2 b2,...,AN bN] = [A1,A2,...,AN] * blkdiag(b1,b2,...,bN)
这允许摆脱循环,但是,我们必须创建大小为C = zeros(N,N);
for i = 1:N
C(:,i) = A(:,:,i)*B(:,i);
end
的块对角矩阵。我通过N^2 x N
制作它是有效的,就像这样:
sparse
根据我的基准测试,尽管进行了必要的准备(单独乘法比循环快3到4倍),但这种方法比循环快了大约2倍(对于大N)。
这是我做的快速基准,改变问题大小A_long = reshape(A,N,N^2);
b_cell = mat2cell(B,N,ones(1,N)); % convert matrix to cell array of vectors
b_cell{1} = sparse(b_cell{1}); % make first element sparse, this is enough to trigger blkdiag into sparse mode
B_blk = blkdiag(b_cell{:});
C = A_long*B_blk;
并测量循环的时间和替代方法(有和没有准备步骤)。对于大N
,加速大约是2 ... 2.5。
尽管如此,这看起来非常复杂。是否有更简单或更好的方法来实现这一目标?这看起来像是一个非常通用/标准的问题,所以我可以想象解决方案就在身边,我只是不知道该搜索什么。
P.S。:N
是一个明显的选择,但这里的块对角线已经blkdiag(A1,...,AN)*B
所以我认为它不会比我做的更好。
修改:感谢大家的评论!我在Matlab R2016b上进行了新的基准测试。不幸的是,我在同一台计算机上没有这两个版本,所以我们无法比较绝对数字,但相对比较仍然很有趣,因为它已经改变了一点。这是:
以下是对高N区域的放大:
观察结果:
N^2 x N^2
将R2016b简化为squeeze(sum(bsxfun(@times,A,permute(B,[3,1,2])),2))
。它比高squeeze(sum(A.*permute(B,[3,1,2]),2))
的循环快约1.2 ... 1.4。N
的准备开销似乎可以忽略不计,这使得它总体上比循环快3 ... 4倍。这是一个很好的结果。