Vectorized array comparison

Time: 2014-01-22 15:26:39

Tags: matlab vectorization jit

I have surjective functions, created by matching an element of the array MatchesX.trainIdx to one or more elements of a second array MatchesX.queryIdx.

To keep only the bijective elements of such a function, I run the same function forwards

Matches1=Matcher.match(Descriptors1,Descriptors2);

and then backwards

Matches2=Matcher.match(Descriptors2,Descriptors1);

and then find the elements that appear in both functions like this:

k=1;
DoubleMatches=Matches1;

for i=1:length(Matches1)
    for j=1:length(Matches2)
        if((Matches1(i).queryIdx==Matches2(j).trainIdx)&&(Matches1(i).trainIdx==Matches2(j).queryIdx))
            DoubleMatches(k)=Matches1(i);
            k=k+1;
        end
    end
end

DoubleMatches(k:end)=[];

This works, of course, but it is rather inelegant and seems to defeat the JIT accelerator (the computation time is the same with accel on and accel off).

Can you think of a way to vectorize this expression? Or is there another way to keep the JIT from "going on strike"?

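Many thanks, and sorry for the odd structures; I'm working with MEX functions. Let me know if it would help to rewrite the code using "normal" arrays.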

2 answers:

Answer 0 (score: 2)

Accessing data inside struct arrays is very slow in MATLAB, so converting the fields to plain numeric arrays first will certainly help:

kk = 1;
DoubleMatches = Matches1;

%// transform to regular array
Matches1queryIdx = [Matches1.queryIdx];
Matches1trainIdx = [Matches1.trainIdx];

Matches2queryIdx = [Matches2.queryIdx];
Matches2trainIdx = [Matches2.trainIdx];

%// loop through transformed data instead of structures
for ii = 1:length(Matches1queryIdx)
    for jj = 1:length(Matches2trainIdx)
        if((Matches1queryIdx(ii)==Matches2trainIdx(jj)) && ...
                (Matches1trainIdx(ii)==Matches2queryIdx(jj)))
            DoubleMatches(kk) = Matches1(ii);
            kk = kk+1;
        end
    end
end

DoubleMatches(kk:end)=[];
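A middle ground between the double loop and the fully vectorized version is to keep only the outer loop and compare one element of Matches1 against all of Matches2 at once, which keeps peak memory O(N). The sketch below is an illustration under assumptions: the four-element data set is made up, the fields are scalar queryIdx/trainIdx values as in the question, and repelem requires MATLAB R2015a or newer.

```matlab
% Made-up data in the same layout as Matches1/Matches2 (column struct arrays)
Matches1 = struct('queryIdx', num2cell([1; 2; 3; 2]), 'trainIdx', num2cell([5; 6; 7; 6]));
Matches2 = struct('queryIdx', num2cell([5; 6; 6; 9]), 'trainIdx', num2cell([1; 2; 2; 3]));

% Pull the fields out into plain numeric row vectors once
q1 = [Matches1.queryIdx];  t1 = [Matches1.trainIdx];
q2 = [Matches2.queryIdx];  t2 = [Matches2.trainIdx];

counts = zeros(1, numel(Matches1));
for ii = 1:numel(Matches1)
    % Vectorized inner comparison: how many Matches2 entries point back to Matches1(ii)
    counts(ii) = nnz(q1(ii) == t2 & t1(ii) == q2);
end

% Repeat Matches1(ii) counts(ii) times, in the same order as the nested loop
DoubleMatches = Matches1(repelem(1:numel(Matches1), counts));
```

With this data, counts is [1 2 0 2] and DoubleMatches holds five structs.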

There is also an almost fully vectorized solution:

matches = sum(...
    bsxfun(@eq, [Matches1.queryIdx], [Matches2.trainIdx].') & ...
    bsxfun(@eq, [Matches1.trainIdx], [Matches2.queryIdx].'), 1);

contents = arrayfun(@(x) ...
    repmat(Matches1(x),1,matches(x)), 1:numel(matches), ...
    'UniformOutput', false);

DoubleMatches2 = [contents{:}]';
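To see what matches contains, here is the same bsxfun expression on four made-up index pairs. The logical matrix has one row per Matches2 element and one column per Matches1 element, so summing along dimension 1 gives, for each Matches1 element, the number of times the loop would copy it.

```matlab
q1 = [1 2 3 2];  t1 = [5 6 7 6];   % made-up Matches1 fields
q2 = [5 6 6 9];  t2 = [1 2 2 3];   % made-up Matches2 fields

% N2-by-N1 logical matrix: true where Matches2(row) points back to Matches1(col)
hits = bsxfun(@eq, q1, t2.') & bsxfun(@eq, t1, q2.');

% Column sums: how many partners each Matches1 element has
matches = sum(hits, 1);   % [1 2 0 2]
```

Passing the dimension explicitly keeps the result a row of counts even when Matches2 has a single element.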

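Note that this can use a lot of memory: its peak memory footprint is O(N²), whereas the loop-based versions are O(N), although the data at peak is of type logical, which is 8 times smaller than double. It is best to check beforehand which of the two you should use.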
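If the O(N²) matrix is too large, the same count vector can be computed with sort-based functions instead, so memory stays linear in N. This is a swapped-in technique (unique, accumarray and ismember with 'rows'), not the answer's method, shown on made-up data; the resulting matches vector is meant to replace the bsxfun result before the arrayfun/repmat step.

```matlab
q1 = [1 2 3 2];  t1 = [5 6 7 6];   % made-up Matches1 fields
q2 = [5 6 6 9];  t2 = [1 2 2 3];   % made-up Matches2 fields

% Matches1 pair (queryIdx, trainIdx) must equal Matches2 pair (trainIdx, queryIdx)
P1 = [q1(:) t1(:)];
P2 = [t2(:) q2(:)];

% Count how often each distinct reversed pair occurs in Matches2
[U, ~, ic] = unique(P2, 'rows');
mult = accumarray(ic(:), 1);

% Look up every Matches1 pair among the distinct Matches2 pairs
[tf, loc] = ismember(P1, U, 'rows');
matches = zeros(1, size(P1, 1));
matches(tf) = mult(loc(tf));   % unmatched pairs keep a count of 0
```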

A quick test. I used the following dummy data:

Matches1 = struct(...
    'queryIdx', num2cell(randi(25,1000,1)),...
    'trainIdx', num2cell(randi(25,1000,1))...
);

Matches2 = struct(...
    'queryIdx', num2cell(randi(25,1000,1)),...
    'trainIdx', num2cell(randi(25,1000,1))...
);

and the following test:

%// Your original method
tic    
    kk = 1;
    DoubleMatches = Matches1;

    for ii = 1:length(Matches1)
        for jj = 1:length(Matches2)
            if((Matches1(ii).queryIdx==Matches2(jj).trainIdx) && ...
                    (Matches1(ii).trainIdx==Matches2(jj).queryIdx))
                DoubleMatches(kk) = Matches1(ii);
                kk = kk+1;
            end
        end
    end

    DoubleMatches(kk:end)=[];

toc

DoubleMatches1 = DoubleMatches;


%// Method with data transformed into regular array
tic

    kk = 1;
    DoubleMatches = Matches1;

    Matches1queryIdx = [Matches1.queryIdx];
    Matches1trainIdx = [Matches1.trainIdx];

    Matches2queryIdx = [Matches2.queryIdx];
    Matches2trainIdx = [Matches2.trainIdx];

    for ii = 1:length(Matches1queryIdx)
        for jj = 1:length(Matches2trainIdx)
            if((Matches1queryIdx(ii)==Matches2trainIdx(jj)) && ...
                    (Matches1trainIdx(ii)==Matches2queryIdx(jj)))
                DoubleMatches(kk) = Matches1(ii);
                kk = kk+1;
            end
        end
    end

    DoubleMatches(kk:end)=[];

toc

DoubleMatches2 = DoubleMatches;


%// Vectorized method
tic

    matches = sum(...
        bsxfun(@eq, [Matches1.queryIdx], [Matches2.trainIdx].') & ...
        bsxfun(@eq, [Matches1.trainIdx], [Matches2.queryIdx].'), 1);

    contents = arrayfun(@(x)repmat(Matches1(x),1,matches(x)), 1:numel(matches), 'Uniformoutput', false);

    DoubleMatches3 = [contents{:}]';

toc

%// Check if all are equal
isequal(DoubleMatches1,DoubleMatches2, DoubleMatches3)

Results:

Elapsed time is 6.350679 seconds. %// (  1×) original method
Elapsed time is 0.636479 seconds. %// (~10×) method with regular array
Elapsed time is 0.165935 seconds. %// (~40×) vectorized
ans =
     1                            %// indeed, outcomes are equal

Answer 1 (score: 0)

Assuming Matcher.match returns an array of the same kind of objects it is passed as arguments, you can solve the problem like this:

% m1 are all d1s which have relation to d2
m1 = Matcher.match(d1,d2);
% m2 are all d2s, which have relation to m1
% and all m1 already have backward relation
m2 = Matcher.match(d2,m1);
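The backward-consistency idea can also be written with plain index arrays. This sketch assumes (not confirmed by the question) that the matcher returns exactly one best match per query descriptor, as OpenCV's DescriptorMatcher::match does, and that the indices have been converted to 1-based lookup vectors; fwd and bwd are hypothetical names and values. A forward match i to fwd(i) is kept only if matching back from fwd(i) lands on i again.

```matlab
% fwd(i): set-2 index of the best match for descriptor i of set 1 (made up)
% bwd(j): set-1 index of the best match for descriptor j of set 2 (made up)
fwd = [2 3 1 3];
bwd = [4 2 2];

% Mutual check by direct indexing: O(N) time and memory, no pairwise comparison
mutual = find(bwd(fwd) == 1:numel(fwd));

% One row per bijective pair: [set-1 index, set-2 index]
pairs = [mutual(:) fwd(mutual).'];
```

Here only descriptor 2 survives, paired with descriptor 3 of the second set.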