我在应用程序中遇到一个小的性能瓶颈,需要从大的方阵中删除非对角元素。所以,矩阵x
17 24 1 8 15
23 5 7 14 16
4 6 13 20 22
10 12 19 21 3
11 18 25 2 9
变为
17 0 0 0 0
0 5 0 0 0
0 0 13 0 0
0 0 0 21 0
0 0 0 0 9
问题:下面的bsxfun和diag解决方案是迄今为止最快的解决方案,我怀疑我可以在保持Matlab代码的同时改进它,但有更快的方法吗?
这是我到目前为止所想到的。
通过单位矩阵执行逐元素乘法。这是最简单的解决方案:
y = x .* eye(n);
使用bsxfun
和diag
:
y = bsxfun(@times, diag(x), eye(n));
下/上三角矩阵:
y = x - tril(x, -1) - triu(x, 1);
使用循环的各种解决方案:
y = x;
for ix=1:n
for jx=1:n
if ix ~= jx
y(ix, jx) = 0;
end
end
end
和
y = x;
for ix=1:n
for jx=1:ix-1
y(ix, jx) = 0;
end
for jx=ix+1:n
y(ix, jx) = 0;
end
end
bsxfun
解决方案实际上是最快的。这是我的计时代码:
function timing()
clear all
n = 5000;
x = rand(n, n);
f1 = @() tf1(x, n);
f2 = @() tf2(x, n);
f3 = @() tf3(x);
f4 = @() tf4(x, n);
f5 = @() tf5(x, n);
t1 = timeit(f1);
t2 = timeit(f2);
t3 = timeit(f3);
t4 = timeit(f4);
t5 = timeit(f5);
fprintf('t1: %f s\n', t1)
fprintf('t2: %f s\n', t2)
fprintf('t3: %f s\n', t3)
fprintf('t4: %f s\n', t4)
fprintf('t5: %f s\n', t5)
end
function y = tf1(x, n)
y = x .* eye(n);
end
function y = tf2(x, n)
y = bsxfun(@times, diag(x), eye(n));
end
function y = tf3(x)
y = x - tril(x, -1) - triu(x, 1);
end
function y = tf4(x, n)
y = x;
for ix=1:n
for jx=1:n
if ix ~= jx
y(ix, jx) = 0;
end
end
end
end
function y = tf5(x, n)
y = x;
for ix=1:n
for jx=1:ix-1
y(ix, jx) = 0;
end
for jx=ix+1:n
y(ix, jx) = 0;
end
end
end
返回
t1: 0.111117 s
t2: 0.078692 s
t3: 0.219582 s
t4: 1.183389 s
t5: 1.198795 s
答案 0 :(得分:9)
我发现:
diag(diag(x))
比bsxfun
快。类似地:
diag(x(1:size(x,1)+1:end))
或多或少相同的量更快。为timeit
玩x=rand(5000)
我的速度比bsxfun
快20倍。
修改强>
这与diag(diag(...
:
x2(n,n)=0;
x2(1:n+1:end)=x(1:n+1:end);
请注意,我预先分配x2
的方式非常重要,如果您只使用x2=zeros(n)
,则会得到较慢的解决方案。在this discussion ...
答案 1 :(得分:8)
我没有费心测试你的各种循环函数,因为它们的实现速度要慢得多,但是我测试了其他循环函数,加上我以前用过的另一种方法:
y = diag(diag(x));
这是剧透:
c1: 193.18 milliseconds // multiply by identity
c2: 102.16 milliseconds // bsxfun
c3: 342.24 milliseconds // tril and triu
c4: 6.03 milliseconds // call diag twice
看起来,diag
的两次调用是我机器上最快的。
接下来是完整的时序代码。我使用自己的基准测试功能而不是timeit
,但结果应该是可比较的(您可以自己检查)。
>> x = randn(5000);
>> c1 = @() x .* eye(5000);
>> c2 = @() bsxfun(@times, diag(x), eye(5000));
>> c3 = @() x - tril(x,-1) - triu(x,1);
>> c4 = @() diag(diag(x));
>> benchmark.bench(c1)
Benchmarking @()x.*eye(5000)
Mean: 193.18 milliseconds, lb 191.94 milliseconds, ub 194.25 milliseconds, ci 95%
Stdev: 6.01 milliseconds, lb 3.27 milliseconds, ub 8.58 milliseconds, ci 95%
>> benchmark.bench(c2)
Benchmarking @()bsxfun(@times,diag(x),eye(5000))
Mean: 102.16 milliseconds, lb 100.83 milliseconds, ub 103.44 milliseconds, ci 95%
Stdev: 6.61 milliseconds, lb 6.04 milliseconds, ub 7.07 milliseconds, ci 95%
>> benchmark.bench(c3)
Benchmarking @()x-tril(x,-1)-triu(x,1)
Mean: 342.24 milliseconds, lb 340.28 milliseconds, ub 344.20 milliseconds, ci 95%
Stdev: 10.06 milliseconds, lb 8.85 milliseconds, ub 11.17 milliseconds, ci 95%
>> benchmark.bench(c4)
Benchmarking @()diag(diag(x))
Mean: 6.03 milliseconds, lb 5.96 milliseconds, ub 6.09 milliseconds, ci 95%
Stdev: 0.34 milliseconds, lb 0.27 milliseconds, ub 0.40 milliseconds, ci 95%