如何使用矢量化加速以下 MATLAB 代码?目前,当 upper = 1e7 时,循环中那一行代码需要运行好几个小时。
以下是带有示例输出的注释代码:
% Store the distinct prime factors of every composite number in
% lower:upper, one number per row of x, padded on the right with ones.
p = 8;      % columns of x; 2*3*5*7*11*13*17*19 = 9699690 < 1e7 has 8 distinct primes
lower = 1;
upper = 1e1;
n = setdiff(lower:upper,primes(upper)); % non-primes in lower..upper (includes 1 when lower <= 1)
x = ones(length(n),p); % Preallocated 2-D array of ones
% Row k of x receives the distinct prime factors of n(k) in ascending
% order. Rows have varying lengths, so the unused tail keeps its ones.
for k = 1:length(n)
    f = unique(factor(n(k)));   % computed once per iteration, not twice
    if numel(f) > size(x,2)
        % Widen with ones; MATLAB's implicit growth would pad with zeros.
        x(:, end+1:numel(f)) = 1;
    end
    x(k,1:numel(f)) = f;
end
输出
x =
1 1 1 1 1 1 1 1
2 1 1 1 1 1 1 1
2 3 1 1 1 1 1 1
2 1 1 1 1 1 1 1
3 1 1 1 1 1 1 1
2 5 1 1 1 1 1 1
例如,忽略用于填充的 1 之后,最后一行就是 10 的素因子(2 和 5)。我把矩阵设为 8 列宽,是为了容纳不超过 1000 万的数可能具有的全部不同素因子(例如 2×3×5×7×11×13×17×19 = 9699690 就有 8 个不同的素因子)。
感谢您的帮助!
答案 0 :(得分:3)
这不是矢量化,但是这个版本的循环将节省大约一半的时间:
% Overwrite the leading entries of each ones-padded row of x with the
% distinct prime factors of the matching element of n; trailing ones stay.
for row = 1:numel(n)
    pf = unique(factor(n(row)));
    x(row, 1:numel(pf)) = pf;
end
以下是一个快速的基准测试:
function t = getPrimeTime(maxExp)
%GETPRIMETIME Benchmark getPrime1 against getPrime2.
%   T = GETPRIMETIME() times both methods on the non-primes of 1:2^k for
%   k = 1..8, prints k as progress and plots the timings.
%   T = GETPRIMETIME(MAXEXP) uses k = 1..MAXEXP instead (MAXEXP >= 2).
%   T(k,j) is the timeit result, in seconds, of method j for range 1:2^k.
%   The plot shows log10 of the time against log2 of the range size.
if nargin < 1
    maxExp = 8;
end
% At least two range sizes are needed to draw one line per method.
validateattributes(maxExp, {'numeric'}, {'scalar', 'integer', '>=', 2});
lo = 1;                 % named lo/hi to avoid shadowing MATLAB's lower/upper
hi = 2.^(1:maxExp);
t = zeros(numel(hi),2);
for k = 1:numel(hi)
    n = setdiff(lo:hi(k),primes(hi(k))); % non-primes in lo..hi(k)
    t(k,1) = timeit(@() getPrime1(n));
    t(k,2) = timeit(@() getPrime2(n));
    disp(k)             % progress indicator
end
h = plot(log2(hi),log10(t));
h(1).Marker = 'o';
h(2).Marker = '*';
xlabel('log_2(range of numbers)')
ylabel('log_{10}(time (sec))')
legend({'getPrime1','getPrime2'})
end
function x = getPrime1(n) % the original function
%GETPRIME1 Baseline: distinct prime factors of each element of n, one per row.
%   X = GETPRIME1(N) returns a numel(N)-by-8 matrix whose row k holds
%   unique(factor(N(k))) in ascending order, padded on the right with ones.
%   Kept verbatim from the question so the benchmark measures the original
%   code; that is why unique(factor(...)) is evaluated twice per element.
p = 8;
x = ones(length(n),p); % Preallocated 2-D array of ones
for k = 1:length(n)
% Errors (size mismatch) if an element has more than p distinct factors.
x(k,:) = [unique(factor(n(k))) ones(1,p-length(unique(factor(n(k)))))];
end
end
function x = getPrime2(n, p)
%GETPRIME2 Distinct prime factors of each element of n, one per row.
%   X = GETPRIME2(N) returns a numel(N)-by-8 matrix whose row k holds the
%   distinct prime factors of N(k) in ascending order, padded with ones.
%   X = GETPRIME2(N, P) uses P columns instead of 8. If an element has more
%   than P distinct prime factors, X is widened and still padded with ones.
if nargin < 2
    p = 8;
end
x = ones(numel(n),p); % Preallocated 2-D array of ones
for k = 1:numel(n)
    tmp = unique(factor(n(k)));   % distinct prime factors, ascending
    m = numel(tmp);
    if m > size(x,2)
        % Widen with ones; MATLAB's implicit growth would pad with zeros.
        x(:, end+1:m) = 1;
    end
    x(k,1:m) = tmp;   % the ones to the right of the factors are kept
end
end
答案 1 :(得分:2)
这是另一种方法:
% Distinct prime factors of every non-prime in lower:upper, computed with a
% divisibility table instead of a per-number loop. The intermediate table
% is numel(n)-by-numel(q), so memory grows quickly with upper.
p = 8;          % minimum number of columns in x
lower = 1;
upper = 1e1;
q = primes(upper);                 % every candidate prime factor
n = setdiff(lower:upper, q);       % non-primes in lower..upper
% x(i,j) = q(j) when q(j) divides n(i), and 0 otherwise
x = bsxfun(@times, q, ~bsxfun(@mod, n(:), q));
% Columns actually needed: the largest distinct-factor count, but at least
% p (otherwise x would be numel(q) columns wide whenever numel(q) > p).
width = max([p; sum(x ~= 0, 2)]);
x(~x) = inf;        % send non-factors to the end of each row ...
x = sort(x,2);      % ... so the factors come first, in ascending order
x(isinf(x)) = 1;    % pad with ones, like the loop versions
ncol = min(width, size(x,2));
x = [x(:, 1:ncol) ones(size(x,1), width - ncol)];
这似乎比其他两种方法更快,但占用更多内存:中间矩阵的大小为 numel(n)×numel(q)。以下借用了 EBH 的基准测试代码:
function t = getPrimeTime(maxExp)
%GETPRIMETIME Benchmark getPrime1, getPrime2 and getPrime3.
%   T = GETPRIMETIME() times each method on the non-primes of 1:2^k for
%   k = 1..12, prints k as progress and plots the timings.
%   T = GETPRIMETIME(MAXEXP) uses k = 1..MAXEXP instead (MAXEXP >= 2).
%   T(k,j) is the timeit result, in seconds, of method j for range 1:2^k.
%   The plot shows log10 of the time against log2 of the range size.
if nargin < 1
    maxExp = 12;
end
% At least two range sizes are needed to draw one line per method.
validateattributes(maxExp, {'numeric'}, {'scalar', 'integer', '>=', 2});
lo = 1;                 % named lo/hi to avoid shadowing MATLAB's lower/upper
hi = 2.^(1:maxExp);
t = zeros(numel(hi),3);
for k = 1:numel(hi)
    n = setdiff(lo:hi(k),primes(hi(k))); % non-primes in lo..hi(k)
    t(k,1) = timeit(@() getPrime1(n));
    t(k,2) = timeit(@() getPrime2(n));
    t(k,3) = timeit(@() getPrime3(n));
    disp(k)             % progress indicator
end
h = plot(log2(hi),log10(t));
h(1).Marker = 'o';
h(2).Marker = '*';
h(3).Marker = '^';
xlabel('log_2(range of numbers)')
ylabel('log_{10}(time (sec))')
legend({'getPrime1','getPrime2','getPrime3'})
grid on
end
function x = getPrime1(n) % the original function
%GETPRIME1 Baseline: distinct prime factors of each element of n, one per row.
%   X = GETPRIME1(N) returns a numel(N)-by-8 matrix whose row k holds
%   unique(factor(N(k))) in ascending order, padded on the right with ones.
%   Kept verbatim from the question so the benchmark measures the original
%   code; that is why unique(factor(...)) is evaluated twice per element.
p = 8;
x = ones(length(n),p); % Preallocated 2-D array of ones
for k = 1:length(n)
% Errors (size mismatch) if an element has more than p distinct factors.
x(k,:) = [unique(factor(n(k))) ones(1,p-length(unique(factor(n(k)))))];
end
end
function x = getPrime2(n, p)
%GETPRIME2 Distinct prime factors of each element of n, one per row.
%   X = GETPRIME2(N) returns a numel(N)-by-8 matrix whose row k holds the
%   distinct prime factors of N(k) in ascending order, padded with ones.
%   X = GETPRIME2(N, P) uses P columns instead of 8. If an element has more
%   than P distinct prime factors, X is widened and still padded with ones.
if nargin < 2
    p = 8;
end
x = ones(numel(n),p); % Preallocated 2-D array of ones
for k = 1:numel(n)
    tmp = unique(factor(n(k)));   % distinct prime factors, ascending
    m = numel(tmp);
    if m > size(x,2)
        % Widen with ones; MATLAB's implicit growth would pad with zeros.
        x(:, end+1:m) = 1;
    end
    x(k,1:m) = tmp;   % the ones to the right of the factors are kept
end
end
function x = getPrime3(n, p) % Approach in this answer
%GETPRIME3 Distinct prime factors of each element of n via a divisibility table.
%   X = GETPRIME3(N) returns a numel(N)-by-8 matrix whose row k holds the
%   distinct prime factors of N(k) in ascending order, padded with ones.
%   X = GETPRIME3(N, P) uses P columns instead of 8. If an element has more
%   than P distinct prime factors, X is widened (still padded with ones).
%   Memory: builds a numel(N)-by-numel(primes(max(N))) intermediate matrix.
if nargin < 2
    p = 8;
end
if isempty(n)
    x = ones(0, p);     % empty input: empty 0-by-p result
    return
end
q = primes(max(n));                             % every candidate prime factor
x = bsxfun(@times, q, ~bsxfun(@mod, n(:), q));  % x(i,j) = q(j) if q(j) divides n(i), else 0
% Columns actually needed: the largest distinct-factor count, but at least
% p (otherwise x would be numel(q) columns wide whenever numel(q) > p).
width = max([p; sum(x ~= 0, 2)]);
x(~x) = inf;        % send non-factors to the end of each row ...
x = sort(x,2);      % ... so the factors come first, in ascending order
x(isinf(x)) = 1;    % pad with ones
ncol = min(width, size(x,2));
x = [x(:, 1:ncol) ones(size(x,1), width - ncol)];
end