优化"掩码"在Matlab中的功能

时间:2015-02-21 10:49:12

标签: python matlab optimization numpy benchmarking

对于基准比较,我考虑简单的功能:

function dealiasing2d(where_dealiased, data)
[n1, n0, nk] = size(data);
for i0=1:n0 
    for i1=1:n1
        if where_dealiased(i1, i0)
            data(i1, i0, :) = 0.;
        end
    end
end

它在伪谱模拟中很有用(其中data是一个复数的3d数组),但基本上它将一个掩码应用于一组图像,将一些元素置为零{{1}是的。

我在这个简单的案例中比较了不同语言(以及实现,编译器......)的性能。对于Matlab,我使用timeit计算函数。由于我不想在Matlab中对我的无知进行基准测试,我想用这种语言真正优化这个功能。在Matlab中最快的方法是什么?

我现在使用的简单解决方案是:

where_dealiased

我怀疑这不是正确的方法,因为等效的Numpy解决方案快了大约10倍。

function dealiasing2d(where_dealiased, data)
[n1, n0, nk] = size(data);
N = n0*n1;
ind_zeros = find(reshape(where_dealiased, 1, []));
for ik=1:nk
    data(ind_zeros + N*(ik-1)) = 0;
end

最后,如果有人告诉我类似于"当你想要在Matlab中使用特定函数非常快时,你应该像这样编译它:[...]",我会包括这样的基准测试中的解决方案。


编辑:

在2个答案之后,我对这些命题进行了基准测试,似乎没有明显的性能提升。这很奇怪,因为简单的Python-Numpy解决方案真的(一个数量级)要快得多,所以我仍然在寻找一个更好的Matlab解决方案......

3 个答案:

答案 0 :(得分:2)

如果我理解正确,可以使用bsxfun

轻松快速地完成此操作
data = bsxfun(@times, data, ~where_dealiased);

这会将0where_dealiased true的所有第三维组成部分设置为0(它将它们乘以1),其余部分保留为他们是(它乘以[size(data,1) size(data,2]==size(where_dealiased))。

当然,这假定为reshape


linear indexing的解决方案可能也非常快。为了节省一些时间,您可以删除find,因为ind_zeros = find(where_dealiased); 已经返回线性索引:

{{1}}

答案 1 :(得分:1)

方法#1:Logical indexing repmat -

data(repmat(where_dealiased,1,1,size(data,3))) = 0;

方法#2:Linear indexingbsxfun(@plus -

[m,n,r] =  size(data);
idx = bsxfun(@plus,find(where_dealiased),[0:r-1]*m*n); %// linear indices
data(idx) = 0;

如果where_dealiased中的非零元素很少,那么这应该很快。

答案 2 :(得分:1)

没有基准测试就没有优化!所以这里有一些提出的解决方案和性能测量。初始化代码是:

N = 2000;
nk = 10;

where = false([N, N]);
where(1:100, 1:100) = 1;
data = (5.+j)*ones([N, N, nk]);

我使用函数timeit对函数进行计时:

timeit(@() dealiasing2d(where, data))

为了进行比较,当我对问题中给出的Numpy函数执行完全相同的操作时,它的运行时间为0.0167秒。

具有2个循环的初始Matlab函数在大约0.34秒内运行,并且等效的Numpy函数(具有2个循环)较慢并且在0.42秒内运行。这可能是因为Matlab使用JIT编译。

Luis Mendo mentions我可以移除reshape因为find已经返回线性索引。我喜欢它,因为代码更清晰但是reshape无论如何非常便宜,所以它并没有真正提高函数的性能:

function dealiasing2d(where, data)
[n1, n0, nk] = size(data);
N = n0*n1;
ind_zeros = find(where);
for ik=1:nk
    data(ind_zeros + N*(ik-1)) = 0;
end

此功能需要0.23秒,这比使用2个循环的解决方案快,但与Numpy解决方案相比确实很慢(大约14倍!)。这就是我写我的问题的原因。

Luis Mendo also proposes基于函数bsxfun的解决方案,它提供:

function dealiasing2d_bsxfun(where, data)
data = bsxfun(@times, data, ~where);

此解决方案涉及N*N*nk次乘法(乘以1或0),这显然太多工作,因为我们只需将数组100*100*nk中的值data置为零。然而,这些乘法可以被矢量化,因此它非常快速。与其他Matlab解决方案相比:0.23 s,即与使用find的第一个解决方案相同!

Both solutions proposed by Divakar涉及创建大量N*N*nk的大型数组。没有Matlab循环所以我们希望有更好的表现,但是......

function dealiasing2d_bsxfun2(where, data)
[n1, n0, nk] = size(data);
idx = bsxfun(@plus, find(where), [0:nk-1]*n1*n0);
data(idx) = 0;

需要0.23秒(与其他功能的时间相同!)和

function dealiasing2d(where, data)
data(repmat(where,[1,1,size(data,3)])) = 0;

需要0.30秒(比其他Matlab解决方案多20%)。

总而言之,在这种情况下似乎存在限制Matlab性能的因素。它也可能是在Matlab中有一个更好的解决方案,或者我在基准测试中做错了...如果有人使用Matlab和Python-Numpy可以提供其他时间,那就太好了。


修改

有关Divakar评论的更多数据:

对于N = 500; nk = 500:

Method          | time (s) | normalized      
----------------|----------|------------
Numpy           |    0.05  |     1.0
Numpy loop      |    0.05  |     1.0
Matlab bsxfun   |    0.70  |    14.0
Matlab find     |    0.75  |    15.0
Matlab bsxfun2  |    0.76  |    15.2
Matlab loop     |    0.77  |    15.4
Matlab repmat   |    0.96  |    19.2

对于N = 500; nk = 100:

Method          | time (s) | normalized      
----------------|----------|------------
Numpy           |    0.01  |     1.0
Numpy loop      |    0.03  |     3.0
Matlab bsxfun   |    0.14  |    12.7
Matlab find     |    0.15  |    13.6
Matlab bsxfun2  |    0.16  |    14.5
Matlab loop     |    0.16  |    14.5
Matlab repmat   |    0.20  |    18.2

对于N = 2000; nk = 10:

Method          | time (s) | normalized |     
----------------|----------|------------|
Numpy           |    0.02  |     1.0    |
Matlab find     |    0.23  |    13.8    |
Matlab bsxfun2  |    0.23  |    13.8    |
Matlab bsxfun   |    0.24  |    14.4    |
Matlab repmat   |    0.30  |    18.0    |
Matlab loop     |    0.34  |    20.4    |
Numpy loop      |    0.42  |    25.1    |

我真的很想知道为什么Matlab与Numpy相比显得那么慢......