在for循环中矢量化查找函数

时间:2013-07-03 17:43:32

标签: algorithm matlab vectorization

我有以下代码输出array1的值小于或等于array2的每个元素的值。这两个数组的长度不同。由于数组很大(~500,000元素),因此for循环非常慢。仅供参考,两个阵列总是按升序排列。

任何帮助使这个矢量操作并加快它的速度将不胜感激。

我正在考虑interp1()采用'最近'选项的某种多步骤流程。然后找到相应的outArray大于array2的位置,然后以某种方式固定点......但我认为必须有更好的方法。

array2 = [5 6 18 25];
array1 = [1 5 9 15 22 24 31];
outArray = nan(size(array2));
for a =1:numel(array2)
    outArray(a) = array1(find(array1 <= array2(a),1,'last'));
end

返回:

outArray =    
     5     5    15    24

3 个答案:

答案 0 :(得分:3)

这是一种可能的矢量化:

[~,idx] = max(cumsum(bsxfun(@le, array1', array2)));
outArray = array1(idx);

编辑:

在最近的版本中,由于JIT编译,MATLAB已经非常擅长执行良好的旧非矢量化循环。

下面是一些类似于你的代码,它利用了两个数组的排序事实(因此如果pos(a) = find(array1<=array2(a), 1, 'last'),那么我们可以保证在下一次迭代中计算的pos(a+1)不小于上一个pos(a)

pos = 1;
idx = zeros(size(array2));
for a=1:numel(array2)
    while pos <= numel(array1) && array1(pos) <= array2(a)
        pos = pos + 1;
    end
    idx(a) = pos-1;
end
%idx(idx==0) = [];      %# in case min(array2) < min(array1)
outArray = array1(idx);

注意:当array2的最小值小于array1的最小值时(即find(array1<=array2(a))为空时),注释行处理大小写

我对目前发布的所有方法进行了比较,这确实是最快的方法。对于长度为N = 5000的向量,定时(使用TIMEIT函数执行)是:

0.097398     # your code
0.39127      # my first vectorized code
0.00043361   # my new code above
0.0016276    # Mohsen Nosratinia's code

以下是N = 500000的时间安排:

(? too-long) # your code
(out-of-mem) # my first vectorized code
0.051197     # my new code above
0.25206      # Mohsen Nosratinia's code

..从你报告的最初10分钟到0.05秒,这是一个相当不错的改进!

如果您想要重现结果,请输入以下测试代码:

function [t,v] = test_array_find()
    %array2 = [5 6 18 25];
    %array1 = [1 5 9 15 22 24 31];
    N = 5000;
    array1 = sort(randi([100 1e6], [1 N]));
    array2 = sort(randi([min(array1) 1e6], [1 N]));

    f = {...
        @() func1(array1,array2);   %# Aero Engy
        @() func2(array1,array2);   %# Amro
        @() func3(array1,array2);   %# Amro
        @() func4(array1,array2);   %# Mohsen Nosratinia
    };

    t = cellfun(@timeit, f);
    v = cellfun(@feval, f, 'UniformOutput',false);
    assert( isequal(v{:}) )
end

function outArray = func1(array1,array2)
    %idx = arrayfun(@(a) find(array1<=a, 1, 'last'), array2);
    idx = zeros(size(array2));
    for a=1:numel(array2)
        idx(a) = find(array1 <= array2(a), 1, 'last');
    end
    outArray = array1(idx);
end

function outArray = func2(array1,array2)
    [~,idx] = max(cumsum(bsxfun(@le, array1', array2)));
    outArray = array1(idx);
end

function outArray = func3(array1,array2)
    pos = 1;
    lastPos = numel(array1);
    idx = zeros(size(array2));
    for a=1:numel(array2)
        while pos <= lastPos && array1(pos) <= array2(a)
            pos = pos + 1;
        end
        idx(a) = pos-1;
    end
    %idx(idx==0) = [];      %# in case min(array2) < min(array1)
    outArray = array1(idx);
end

function outArray = func4(array1,array2)
    [~,I] = sort([array1 array2]);
    a1size = numel(array1);
    J = find(I>a1size);
    outArray = nan(size(array2));
    for k=1:numel(J),
        if  I(J(k)-1)<=a1size,
            outArray(k) = array1(I(J(k)-1));
        else
            outArray(k) = outArray(k-1);
        end
    end
end

答案 1 :(得分:2)

速度缓慢的一个原因是,您要将array1中的所有元素与array2中的所有元素进行比较,以便它们分别包含MN个元素,复杂性为O(M*N)。但是,由于数组已经排序,因此有一个线性时间O(M+N)解决方案

array2 = [5 6 18 25];
array1 = [1 5 9 15 22 24 31];

outArray = nan(size(array2));
k1 = 1;
n1 = numel(array1);
n2 = numel(array2);

ks = 1;
while ks <= n2 && array2(ks) < array1(1)
    ks = ks + 1;
end

for k2=ks:n2
    while k1 < n1 && array2(k2) >= array1(k1+1) 
        k1 = k1+1;
    end
    outArray(k2) = array1(k1);
end

这是一个测试用例,用于测量每个方法运行两个长度为500,000的数组所需的时间。

array2 = 1:500000;
array1 = array2-1;

tic
outArray1 = nan(size(array2));
k1 = 1;
n1 = numel(array1);
n2 = numel(array2);

ks = 1;
while ks <= n2 && array2(ks) < array1(1)
    ks = ks + 1;
end

for k2=ks:n2
    while k1 < n1 && array2(k2) >= array1(k1+1) 
        k1 = k1+1;
    end
    outArray1(k2) = array1(k1);
end
toc    

tic
outArray2 = nan(size(array2));
for a =1:numel(array2)
    outArray2(a) = array1(find(array1 <= array2(a),1,'last'));
end
toc

结果是

Elapsed time is 0.067637 seconds.
Elapsed time is 418.458722 seconds.

答案 2 :(得分:0)

注意: 这是我最初的解决方案,也是Amro答案的基准测试。但是,它比我在其他答案中提供的线性时间解决方案要慢。

速度缓慢的一个原因是您要将array1中的所有元素与array2中的所有元素进行比较,因此如果它们包含MN元素,则复杂性为O(M*N)。但是,您可以将它们连接起来并将它们排序在一起,并获得更快的复杂算法(M+N)*log2(M+N)。这是一种方法:

array2 = [5 6 18 25];
array1 = [1 5 9 15 22 24 31];

[~,I] = sort([array1 array2]);
a1size = numel(array1);
J = find(I>a1size);
outArray = nan(size(array2));
for k=1:numel(J),
    if  I(J(k)-1)<=a1size,
        outArray(k) = array1(I(J(k)-1));
    else
        outArray(k) = outArray(k-1);
    end
end

disp(outArray)

% Test using original code
outArray = nan(size(array2));
for a =1:numel(array2)
    outArray(a) = array1(find(array1 <= array2(a),1,'last'));
end
disp(outArray)

连接数组将是

>> [array1 array2]
ans =
     1     5     9    15    22    24    31     5     6    18    25

>> [B,I] = sort([array1 array2])
B =
     1     5     5     6     9    15    18    22    24    25    31
I =
     1     2     8     9     3     4    10     5     6    11     7

它显示在排序数组B中,第一个5来自连接数组中的第二个位置,第二个位于8个位置,依此类推。因此,要查找array1中小于array2中给定元素的最大元素,我们只需要遍历I中大于array1大小的所有索引(因此属于array2)并返回并找到属于array1的最近索引。 J包含这些元素在向量I中的位置:

>> J = find(I>a1size)
J =
     3     4     7    10

现在,for循环遍历这些索引并检查I之前的索引是否恰好在J引用的每个索引属于array1之前。如果它属于array1,则会从array1检索它,否则会复制为先前索引找到的值。

请注意,如果array2包含的元素小于array1中的最小元素,则代码和此代码都会失败。