批矩阵乘法的优化

时间:2018-08-06 14:26:07

标签: matlab matrix-multiplication

我大致了解以下内容:

A = rand(10, 20, 30);
B = rand(10, 30, 40);

我想获得一个大小为C的矩阵(10, 20, 40),目前使用for循环来实现它:

for i = 1:10
    C(i, :, :) = squeeze(A(i, :, :)) * squeeze(B(i, :, :));
end

我尝试做C = bsxfun(@mtimes, A, B);,但这没用。

最优化的最佳方法是什么?我并不是在寻找易于阅读的代码,只是为了获得我能获得的最优化的东西。

谢谢。

2 个答案:

答案 0 :(得分:2)

在开始之前,重要的是要认识到矩阵乘法是一个非常昂贵的过程。它的渐近复杂度是O(n ^ 3)(O(n ^ 2.8)带有strassens)。这意味着尽管您可能不认为自己在进行大量计算,但实际上甚至发生了数十亿甚至您不知道的计算。因此,由于数量众多,您真正可以做的事情受到限制。

如果您不希望用于循环,则可以使用两种方法在MATLAB中进行批矩阵乘法。

第一个是称为mtimesx的函数。编译后的函数使用稀疏矩阵对过程进行矢量化处理。但是,矩阵必须在前2维中。在代码中将执行此操作。

A = rand(10, 20, 30);
B = rand(10, 30, 40);
A = permute(A,[2 3 1]); % Change the dimensions as mtimesx always multiplies the first 2 dimensions
B = permute(B,[2 3 1]);
C = mtimesx(A,B);
C = permute(C,[3 1 2]);

这通常会更快地完成您的问题所描述的操作。

或者,如果您有GPU,则可以以相同的方式使用pagefun

A = rand(10, 20, 30);
B = rand(10, 30, 40);
A = permute(A,[2 3 1]);
B = permute(B,[2 3 1]);
A = gpuArray(A);
B = gpuArray(B);
C = pagefun(@mtimes,A,B);
C = permute(C,[3 1 2]);

此方法将每个问题发送到GPU的页面上,如果使用单精度,则此方法通常比mtimesx快得多。

我修改了@MarcinKonowalczyk脚本来运行所有示例。如您所见,在这种情况下,mtimesx的执行效果最佳,与其他方法相比有了很大的改进

Results of all methods

此外,此图使用1000个矩阵乘法而不是10个矩阵乘法,在这里我们开始看到GPU优于CPU的优势。

Results with 1000 matrix multiplies

close all; clear;

N = 1000;
N = N+10; % Add a few initial runs to be trimmed off at the end

%% 1st dimension
% Preallocate C
num_problems = 10;
outer_left = 20;
inner = 30;
outer_right = 40;
A = rand(num_problems, outer_left, inner); B = rand(num_problems, inner, outer_right); C = zeros(num_problems, outer_left, outer_right);

t1 = zeros(1,N); % Preallocate timing results vector
for j = 1:N % Do the multiplication N times
    tic
    for i = 1:num_problems
        C(i, :, :) = squeeze(A(i, :, :)) * squeeze(B(i, :, :));
    end
    t1(j) = toc;
end

%% 2nd dimension
A = permute(A,[2 1 3]); B = permute(B,[2 1 3]); C = permute(C,[2 1 3]);

t2 = zeros(1,N);
for j = 1:N
    tic
    for i = 1:num_problems
        C(:, i, :) = squeeze(A(:, i, :)) * squeeze(B(:, i, :));
    end
    t2(j) = toc;
end

%% 3rd dimension
A = permute(A,[1 3 2]); B = permute(B,[1 3 2]); C = permute(C,[1 3 2]);

t3 = zeros(1,N);
for j = 1:N
    tic
    for i = 1:num_problems
        C(:, :, i) = A(:, :, i) * B(:, :, i);
    end
    t3(j) = toc;
end

t4 = zeros(1,N);
for ii = 1:N
    tic
    C = mtimesx(A,B);
    t4(ii) = toc;
end

A = gpuArray(A);
B = gpuArray(B);
t5 = zeros(1,N);
for ii = 1:N
    tic
    C = pagefun(@mtimes,A,B);
    t5(ii) = toc;
end

%% Plot

% Trim initial runs and convert to microsecconds
t1 = t1(11:end)*1e6; t2 = t2(11:end)*1e6; t3 = t3(11:end)*1e6;
t4 = t4(11:end)*1e6; t5 = t5(11:end)*1e6;

x = 1:N-10;
plot(x,t1,x,t2,x,t3,x,t4,x,t5);

grid on;
xlabel('trial number');
ylabel('running time / us');
legend('C(i,:,:)','C(:,i,:)','C(:,:,i)','mtimesx','pagefun');
title(sprintf('t1 = %.0f, t2 = %.0f, t3 = %.0f, t4 = %.0f, t5 = %.0f us',median(t1),median(t2),median(t3),median(t4),median(t5)));

答案 1 :(得分:1)

您可以更改保留大小,以优化内存访问。毕竟,您的矩阵是stored in memory长的一维数组。对它们进行不同的切片可能(并且确实)获得相邻的值,而不是到处乱跳。您的代码为:

A = rand(20, 30, 10);
B = rand(30, 40, 10);
C = zeros(20, 40, 10);

for i = 1:10
    C(:, :, i) = A(:, :, i) * B(:, :, i);
end

请注意,您甚至不需要squeeze,因为Matlab会自动删除尾随的单维尺寸,因此由于减少了函数调用,您可以节省一些常数。

enter image description here 这是我使用的代码:

close all; clear; clc;

N = 1000;
N = N+10; % Add a few initial runs to be trimmed off at the end

%% 1st dimension
% Preallocate C
A = rand(10, 20, 30); B = rand(10, 30, 40); C = zeros(10, 20, 40);

t1 = zeros(1,N); % Preallocate timing results vector
for j = 1:N % Do the multiplication N times
    tic
    for i = 1:10
        C(i, :, :) = squeeze(A(i, :, :)) * squeeze(B(i, :, :));
    end
    t1(j) = toc;
end

%% 2nd dimension
A = rand(20, 10, 30); B = rand(30, 10, 40); C = zeros(20, 10, 40);

t2 = zeros(1,N);
for j = 1:N
    tic
    for i = 1:10
        C(:, i, :) = squeeze(A(:, i, :)) * squeeze(B(:, i, :));
    end
    t2(j) = toc;
end

%% 3rd dimension
A = rand(20, 30, 10); B = rand(30, 40, 10); C = zeros(20, 40, 10);

t3 = zeros(1,N);
for j = 1:N
    tic
    for i = 1:10
        C(:, :, i) = A(:, :, i) * B(:, :, i);
    end
    t3(j) = toc;
end

%% Plot

% Trim initial runs and convert to microsecconds
t1 = t1(11:end)*1e6; t2 = t2(11:end)*1e6; t3 = t3(11:end)*1e6;

x = 1:N-10;
plot(x,t1,x,t2,x,t3);

grid on;
xlabel('trial number');
ylabel('running time / us');
legend('C(i,:,:)','C(:,i,:)','C(:,:,i)');
title(sprintf('t1 = %.0f, t2 = %.0f, t3 = %.0f us',median(t1),median(t2),median(t3)));