加速3D阵列乘法

时间:2016-07-07 15:54:53

标签: matlab multidimensional-array vectorization

我有以下代码片段,它将3D数组的每个2D横截面乘以向量:

A = zeros(N,M);
for k = 1:M
   B = C(:,:,k);
   A(:,k) = B * f(:,k);
end

当我对代码进行分析时,我发现对于N = 200,M = 25,这可能非常慢(相对于我的代码的其他部分);特别是这一行:

B=C(:,:,k)

可占用总运行时间的很大一部分。有什么办法可以加快速度吗?

1 个答案:

答案 0 :(得分:0)

您可以使用bsxfunpermute结合使用单例扩展执行所需的乘法,然后沿适当的维度添加以计算矩阵乘法。

定义示例数据:

N = 5;
M = 4;
C = rand(N,N,M);
f = rand(N,M);

然后结果计算为

result = permute(sum(bsxfun(@times, C, permute(f,[3 1 2])),2), [1 3 2]);

作为检查,与使用循环计算的结果进行比较:

A = zeros(N,M);
for k = 1:M
   B = C(:,:,k);
   A(:,k) = B * f(:,k);
end
result./A

给出

ans =
     1     1     1     1
     1     1     1     1
     1     1     1     1
     1     1     1     1
     1     1     1     1