Vectorization in MATLAB to speed up an expensive loop

Time: 2016-11-01 16:48:20

Tags: matlab vectorization

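How can I use vectorization to speed up the following MATLAB code? Right now, the single line inside the loop takes hours to run for the case upper = 1e7.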

Here is the commented code with sample output:

p = 8;
lower = 1;
upper = 1e1;
n = setdiff(lower:upper,primes(upper)); % contains composite numbers between lower + upper
x = ones(length(n),p); % Preallocated 2-D array of ones

% This loop stores the unique prime factors of each composite
% number from 1 to n, in each row of x. Since the rows will have
% varying lengths, the rows are padded with ones at the end.

for i = 1:length(n)
    x(i,:) = [unique(factor(n(i))) ones(1,p-length(unique(factor(n(i)))))];
end

Output:

x =

 1     1     1     1     1     1     1     1
 2     1     1     1     1     1     1     1
 2     3     1     1     1     1     1     1
 2     1     1     1     1     1     1     1
 3     1     1     1     1     1     1     1
 2     5     1     1     1     1     1     1

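For example, the last row contains the prime factors of 10, if you ignore the ones. I made the matrix 8 columns wide to account for the many prime factors of numbers up to 10 million.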
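
The choice of 8 columns can be checked directly: the smallest integer with k distinct prime factors is the product of the first k primes, so it is enough to see how many of those running products stay below 1e7. A minimal sketch (the variable names firstPrimes, primorials and maxDistinct are illustrative, not from the original code):

```matlab
% Smallest integers with 1, 2, 3, ... distinct prime factors are the
% running products of the primes 2, 3, 5, ...
firstPrimes = primes(25);                          % 2 3 5 7 11 13 17 19 23
primorials  = cumprod(firstPrimes);                % 2, 6, 30, ..., 9699690, 223092870
maxDistinct = find(primorials <= 1e7, 1, 'last');  % number of products not exceeding 1e7
disp(maxDistinct)                                  % displays 8
```

Since 2*3*5*7*11*13*17*19 = 9699690 is below 1e7 and multiplying by 23 exceeds it, no number up to 1e7 has more than 8 distinct prime factors.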

Thanks for your help!

2 answers:

Answer 0 (score: 3)

This is not vectorization, but this version of the loop saves about half the time:

for k = 1:numel(n)
    tmp = unique(factor(n(k)));
    x(k,1:numel(tmp)) = tmp;
end
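
The saving is plausible because the original line evaluates unique(factor(...)) twice per number, while this loop evaluates it once and relies on the preallocated ones for padding. A minimal sketch to confirm that both loops fill the array identically on the small example (xOld, xNew and same are illustrative names):

```matlab
p = 8;
n = setdiff(1:10, primes(10));   % 1 and the composites up to 10
xOld = ones(numel(n), p);        % filled by the original one-liner
xNew = ones(numel(n), p);        % filled by the shorter loop
for k = 1:numel(n)
    % Original: factor and unique are each called twice
    xOld(k,:) = [unique(factor(n(k))) ones(1, p-length(unique(factor(n(k)))))];
    % Revised: factor and unique are each called once
    tmp = unique(factor(n(k)));
    xNew(k, 1:numel(tmp)) = tmp;
end
same = isequal(xOld, xNew);      % logical 1 if both results match
```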

Here is a quick benchmark for this:

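[Figure: "get prime" benchmark plot; x-axis log_2(range of numbers), y-axis log(time (sec)), curves getPrime1 and getPrime2]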

function t = getPrimeTime
lower = 1;
upper = 2.^(1:8);
t = zeros(numel(upper),2);
for k = 1:numel(upper)
    n = setdiff(lower:upper(k),primes(upper(k))); % contains composite numbers between lower to upper
    t(k,1) = timeit(@() getPrime1(n));
    t(k,2) = timeit(@() getPrime2(n));
    disp(k)
end
p = plot(log2(upper),log10(t));
p(1).Marker = 'o';
p(2).Marker = '*';
xlabel('log_2(range of numbers)')
ylabel('log(time (sec))')
legend({'getPrime1','getPrime2'})
end

function x = getPrime1(n) % the original function
p = 8;
x = ones(length(n),p); % Preallocated 2-D array of ones
for k = 1:length(n)
    x(k,:) = [unique(factor(n(k))) ones(1,p-length(unique(factor(n(k)))))];
end
end

function x = getPrime2(n)
p = 8;
x = ones(numel(n),p); % Preallocated 2-D array of ones
for k = 1:numel(n)
    tmp = unique(factor(n(k)));
    x(k,1:numel(tmp)) = tmp;
end
end

Answer 1 (score: 2)

Here is another approach:

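p = 8;
lower = 1;
upper = 1e1;
q = primes(upper);
n = setdiff(lower:upper, q);                     % 1 and the composite numbers
x = bsxfun(@times, q, ~bsxfun(@mod, n(:), q));   % q(j) where q(j) divides n(i), otherwise 0
x(~x) = inf;                                     % mark non-factors so they sort to the end
x = sort(x,2);                                   % factors first, in ascending order
x(isinf(x)) = 1;                                 % turn the markers into ones
x = [x ones(size(x,1), p-size(x,2))];            % pad to p columns (adds nothing if numel(q) >= p)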

This seems to be faster than the other two options (but uses more memory). Borrowing EBH's benchmarking code:

function t = getPrimeTime
lower = 1;
upper = 2.^(1:12);
t = zeros(numel(upper),3);
for k = 1:numel(upper)
    n = setdiff(lower:upper(k),primes(upper(k)));
    t(k,1) = timeit(@() getPrime1(n));
    t(k,2) = timeit(@() getPrime2(n));
    t(k,3) = timeit(@() getPrime3(n));
    disp(k)
end
p = plot(log2(upper),log10(t));
p(1).Marker = 'o';
p(2).Marker = '*';
p(3).Marker = '^';
xlabel('log_2(range of numbers)')
ylabel('log(time (sec))')
legend({'getPrime1','getPrime2','getPrime3'})
grid on
end

function x = getPrime1(n) % the original function
p = 8;
x = ones(length(n),p); % Preallocated 2-D array of ones
for k = 1:length(n)
    x(k,:) = [unique(factor(n(k))) ones(1,p-length(unique(factor(n(k)))))];
end
end

function x = getPrime2(n)
p = 8;
x = ones(numel(n),p); % Preallocated 2-D array of ones
for k = 1:numel(n)
    tmp = unique(factor(n(k)));
    x(k,1:numel(tmp)) = tmp;
end
end

function x = getPrime3(n) % Approach in this answer
p = 8;
q = primes(max(n));
x = bsxfun(@times, q, ~bsxfun(@mod, n(:), q));
x(~x) = inf;
x = sort(x,2);
x(isinf(x)) = 1;
x = [x ones(size(x,1), p-size(x,2))];
end
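
On the memory caveat in this answer: the bsxfun approach builds a numel(n)-by-numel(q) matrix, which for upper = 1e7 is roughly 9.3e6 rows by 6.6e5 columns, far more than fits in memory. One possible alternative, not taken from either answer, is a sieve-style loop over the primes instead of over the composites; it needs only an upper-by-p array. The following is a minimal sketch under that assumption. The names cnt and idx are illustrative, and it has not been benchmarked against the versions above.

```matlab
% Sieve-style sketch (an assumption, not the approach of either answer)
p = 8;
lower = 1;
upper = 1e1;
q = primes(upper);
x = ones(upper, p);      % row m will hold the distinct prime factors of m
cnt = zeros(upper, 1);   % number of prime factors stored so far for each m
for k = 1:numel(q)
    idx = (q(k):q(k):upper).';               % all multiples of q(k)
    cnt(idx) = cnt(idx) + 1;                 % next free column for those rows
    x(idx + (cnt(idx) - 1) * upper) = q(k);  % linear index of (idx, cnt(idx))
end
n = setdiff(lower:upper, q);                 % 1 and the composite numbers
x = x(n, :);                                 % keep only the rows asked for
```

For upper = 1e1 this reproduces the example output in the question. Because the primes are visited in ascending order, each row is already sorted. The sketch relies on p being at least the largest number of distinct prime factors in the range; otherwise the index computation is no longer valid.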