I have something roughly like the following:
A = rand(10, 20, 30);
B = rand(10, 30, 40);
I want to get a matrix C of size (10, 20, 40). Currently I do it with a for loop:
for i = 1:10
    C(i, :, :) = squeeze(A(i, :, :)) * squeeze(B(i, :, :));
end
I tried C = bsxfun(@mtimes, A, B); but that didn't work.
What is the most optimized way to do this? I'm not looking for readable code, just the fastest thing I can get.
Thanks.
Answer 0 (score: 2)
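Before starting, it is important to realize that matrix multiplication is a very expensive operation. Its asymptotic complexity is O(n^3) (about O(n^2.8) with Strassen's algorithm). This means that even if it doesn't look like you are doing much computation, a great deal of arithmetic is happening behind the scenes. Because of that sheer volume, there is a limit to how much you can actually speed it up.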
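To put a number on it for the sizes in the question: a plain (non-Strassen) product of an m×k matrix by a k×n matrix takes m·k·n multiply-adds, so the whole batch of P = 10 products of 20×30 by 30×40 matrices costs

$$
P \cdot m \cdot k \cdot n = 10 \cdot 20 \cdot 30 \cdot 40 = 240\,000 \ \text{multiply-adds}
$$

and that count does not change however the loop is written; only the overhead around it can shrink.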
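If you don't want to use for loops, there are two ways to do batched matrix multiplication in MATLAB.
The first is a function called mtimesx, a compiled (MEX) function available on the MATLAB File Exchange. It vectorizes the process using sparse matrices. However, the matrices being multiplied must lie along the first 2 dimensions, so the batch dimension has to be moved to the end. In code, that looks like this: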
A = rand(10, 20, 30);
B = rand(10, 30, 40);
A = permute(A,[2 3 1]); % Change the dimensions as mtimesx always multiplies the first 2 dimensions
B = permute(B,[2 3 1]);
C = mtimesx(A,B);
C = permute(C,[3 1 2]);
This will usually do what your question describes faster than the loop.
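Because mtimesx is a third-party MEX file that may not be installed everywhere, here is a minimal sketch for checking the permute bookkeeping above, with a plain loop over pages standing in for mtimesx (the names Ap, Bp, Cp, Cref and maxErr are illustrative, not from the answer):

```matlab
% Check the permute bookkeeping used with mtimesx. A plain loop over pages
% stands in for mtimesx, which multiplies along the first two dimensions.
A = rand(10, 20, 30);
B = rand(10, 30, 40);

% Reference result: the original loop over the first (batch) dimension
Cref = zeros(10, 20, 40);
for i = 1:10
    Cref(i, :, :) = squeeze(A(i, :, :)) * squeeze(B(i, :, :));
end

% Move the batch dimension to the end: 20x30x10 and 30x40x10
Ap = permute(A, [2 3 1]);
Bp = permute(B, [2 3 1]);

% Page-by-page product, i.e. what mtimesx(Ap, Bp) is expected to compute
Cp = zeros(20, 40, 10);
for i = 1:10
    Cp(:, :, i) = Ap(:, :, i) * Bp(:, :, i);
end

% Move the batch dimension back to the front: 10x20x40
C = permute(Cp, [3 1 2]);

% Largest absolute difference from the reference (zero or round-off level)
maxErr = max(abs(C(:) - Cref(:)));
```

If the permutes are right, maxErr is zero or at round-off level, and the same two permute calls wrap mtimesx unchanged.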
Alternatively, if you have a GPU (and the Parallel Computing Toolbox), you can use pagefun in the same way.
A = rand(10, 20, 30);
B = rand(10, 30, 40);
A = permute(A,[2 3 1]);
B = permute(B,[2 3 1]);
A = gpuArray(A);
B = gpuArray(B);
C = pagefun(@mtimes,A,B);
C = permute(C,[3 1 2]);
This method sends each problem to its own page on the GPU. In single precision it is usually much faster than mtimesx.
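The result of pagefun stays in GPU memory until gather copies it back. A minimal sketch of the single-precision variant, assuming the Parallel Computing Toolbox and a supported GPU are present; the useGpu guard and the CPU fallback loop are illustrative additions, not part of the answer:

```matlab
% Sketch: batched product on the GPU in single precision, then gather the
% result back to host memory. Falls back to a CPU page loop when no GPU
% (or no Parallel Computing Toolbox) is available.
A = rand(10, 20, 30);
B = rand(10, 30, 40);
Ap = permute(A, [2 3 1]);   % 20x30x10, pages along dimension 3
Bp = permute(B, [2 3 1]);   % 30x40x10

% Illustrative guard: gpuDeviceCount only exists with the toolbox installed
useGpu = exist('gpuDeviceCount', 'file') && gpuDeviceCount > 0;
if useGpu
    Cp = pagefun(@mtimes, gpuArray(single(Ap)), gpuArray(single(Bp)));
    Cp = gather(Cp);        % copy the 20x40x10 result back to the CPU
else
    Cp = zeros(20, 40, 10, 'single');
    for i = 1:10
        Cp(:, :, i) = single(Ap(:, :, i)) * single(Bp(:, :, i));
    end
end

C = permute(Cp, [3 1 2]);   % 10x20x40, class single
```

Single precision halves the data sent to the GPU, at the cost of roughly 7 significant digits instead of 16.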
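I modified @MarcinKonowalczyk's script to run all of the examples. As you can see, mtimesx performs best in this case, a large improvement over the other methods.
With 1000 matrix multiplications instead of 10 (num_problems = 1000 in the script below), we start to see the GPU's advantage over the CPU.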
close all; clear;
N = 1000;
N = N+10;  % Add a few initial runs to be trimmed off at the end

%% 1st dimension
num_problems = 10;  % Batch size: the number of matrix products
outer_left = 20;
inner = 30;
outer_right = 40;
A = rand(num_problems, outer_left, inner);
B = rand(num_problems, inner, outer_right);
C = zeros(num_problems, outer_left, outer_right);  % Preallocate C
t1 = zeros(1,N);  % Preallocate timing results vector
for j = 1:N  % Do the multiplication N times
    tic
    for i = 1:num_problems
        C(i, :, :) = squeeze(A(i, :, :)) * squeeze(B(i, :, :));
    end
    t1(j) = toc;
end

%% 2nd dimension
A = permute(A,[2 1 3]); B = permute(B,[2 1 3]); C = permute(C,[2 1 3]);
t2 = zeros(1,N);
for j = 1:N
    tic
    for i = 1:num_problems
        C(:, i, :) = squeeze(A(:, i, :)) * squeeze(B(:, i, :));
    end
    t2(j) = toc;
end

%% 3rd dimension
A = permute(A,[1 3 2]); B = permute(B,[1 3 2]); C = permute(C,[1 3 2]);
t3 = zeros(1,N);
for j = 1:N
    tic
    for i = 1:num_problems
        C(:, :, i) = A(:, :, i) * B(:, :, i);
    end
    t3(j) = toc;
end

%% mtimesx (A and B are already in the pages-last layout)
t4 = zeros(1,N);
for ii = 1:N
    tic
    C = mtimesx(A,B);
    t4(ii) = toc;
end

%% pagefun on the GPU
A = gpuArray(A);
B = gpuArray(B);
gpu = gpuDevice;
t5 = zeros(1,N);
for ii = 1:N
    tic
    C = pagefun(@mtimes,A,B);
    wait(gpu);  % GPU calls return asynchronously; wait so toc measures the actual work
    t5(ii) = toc;
end

%% Plot
% Trim initial runs and convert to microseconds
t1 = t1(11:end)*1e6; t2 = t2(11:end)*1e6; t3 = t3(11:end)*1e6;
t4 = t4(11:end)*1e6; t5 = t5(11:end)*1e6;
x = 1:N-10;
plot(x,t1,x,t2,x,t3,x,t4,x,t5);
grid on;
xlabel('trial number');
ylabel('running time / us');
legend('C(i,:,:)','C(:,i,:)','C(:,:,i)','mtimesx','pagefun');
title(sprintf('t1 = %.0f, t2 = %.0f, t3 = %.0f, t4 = %.0f, t5 = %.0f us',median(t1),median(t2),median(t3),median(t4),median(t5)));
Answer 1 (score: 1)
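You can reorder the dimensions of your arrays to optimize memory access. After all, your matrices are stored in memory as one long 1-D array (in column-major order). Slicing them differently can (and does) let you read adjacent values instead of jumping all over memory. Your code then becomes: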
A = rand(20, 30, 10);
B = rand(30, 40, 10);
C = zeros(20, 40, 10);
for i = 1:10
    C(:, :, i) = A(:, :, i) * B(:, :, i);
end
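Note that you don't even need squeeze, because MATLAB automatically drops trailing singleton dimensions, so you also save a small constant overhead from the fewer function calls.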
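To make the memory-layout argument concrete, here is a small sketch (the page index i = 4 and the variable names are arbitrary, chosen for illustration) showing that a page of a 20x30x10 array is one contiguous block of linear indices, while the same page in the question's 10x20x30 layout is strided, and that only the latter keeps a leading singleton dimension:

```matlab
% MATLAB stores arrays in column-major order: the first index varies fastest.
A = rand(20, 30, 10);       % pages-last layout, as in this answer
i = 4;                      % any page index from 1 to 10
pageLen = 20 * 30;          % elements per 20x30 page

% In the pages-last layout, page i is one contiguous run of memory.
contigIdx = (i-1)*pageLen + (1:pageLen);
pageContig = reshape(A(contigIdx), 20, 30);

% Same data in the question's pages-first layout (10x20x30):
% A2(i,:,:) holds A(:,:,i), but its elements sit 10 apart in memory.
A2 = permute(A, [3 1 2]);
stridedIdx = i:10:numel(A2);
pageStrided = reshape(A2(stridedIdx), 20, 30);

% Trailing singleton dimensions are dropped, leading ones are not,
% which is why only the pages-first layout needs squeeze.
szLast  = size(A(:, :, i));    % [20 30]
szFirst = size(A2(i, :, :));   % [1 20 30]
```

Both pageContig and pageStrided equal A(:,:,i); the difference is only how far apart the elements are in memory, which is what the timings below measure.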
% Benchmark: compare the three slicing layouts
close all; clear; clc;
N = 1000;
N = N+10;  % Add a few initial runs to be trimmed off at the end

%% 1st dimension
% Preallocate C
A = rand(10, 20, 30); B = rand(10, 30, 40); C = zeros(10, 20, 40);
t1 = zeros(1,N);  % Preallocate timing results vector
for j = 1:N  % Do the multiplication N times
    tic
    for i = 1:10
        C(i, :, :) = squeeze(A(i, :, :)) * squeeze(B(i, :, :));
    end
    t1(j) = toc;
end

%% 2nd dimension
A = rand(20, 10, 30); B = rand(30, 10, 40); C = zeros(20, 10, 40);
t2 = zeros(1,N);
for j = 1:N
    tic
    for i = 1:10
        C(:, i, :) = squeeze(A(:, i, :)) * squeeze(B(:, i, :));
    end
    t2(j) = toc;
end

%% 3rd dimension
A = rand(20, 30, 10); B = rand(30, 40, 10); C = zeros(20, 40, 10);
t3 = zeros(1,N);
for j = 1:N
    tic
    for i = 1:10
        C(:, :, i) = A(:, :, i) * B(:, :, i);  % no squeeze needed
    end
    t3(j) = toc;
end

%% Plot
% Trim initial runs and convert to microseconds
t1 = t1(11:end)*1e6; t2 = t2(11:end)*1e6; t3 = t3(11:end)*1e6;
x = 1:N-10;
plot(x,t1,x,t2,x,t3);
grid on;
xlabel('trial number');
ylabel('running time / us');
legend('C(i,:,:)','C(:,i,:)','C(:,:,i)');
title(sprintf('t1 = %.0f, t2 = %.0f, t3 = %.0f us',median(t1),median(t2),median(t3)));