Mean average precision for multi-label, multi-class data

Time: 2017-12-28 06:59:28

Tags: matlab performance metric

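I am trying to write code that computes the mean average precision (MAP) for multi-label data. The original post illustrated the definition with an image, which is not reproduced here.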
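Since the image is unavailable, the following is a reconstruction of the quantity that the MATLAB code in this question actually computes. It is derived from the code, not from the image. The symbols mirror the variable names `deltax`, `Lrx`, `Px`, `Lx` and `APx`; $n$ (the number of queries, `tn`) and the index $x$ for a query are notation assumed here.

```latex
\begin{align}
\delta_x(r) &=
  \begin{cases}
    1 & \text{if the $r$-th retrieved item shares at least one label with query $x$,}\\
    0 & \text{otherwise,}
  \end{cases} \\
L_x(r) &= \sum_{j=1}^{r} \delta_x(j), \qquad
P_x(r) = \frac{L_x(r)}{r}, \qquad
L_x = \sum_{r=1}^{R} \delta_x(r), \\
AP_x &= \frac{1}{L_x} \sum_{r=1}^{R} P_x(r)\,\delta_x(r)
  \quad (\text{with } AP_x = 0 \text{ when } L_x = 0), \\
\mathrm{MAP} &= \frac{1}{n} \sum_{x=1}^{n} AP_x .
\end{align}
```

In the code, $R$ is fixed at 100 and the ranking comes from sorting each column of `sim_x` in descending order.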

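I wrote the MAP computation in MATLAB, but it is slow. The main reason is that the variable Lrx is recomputed from scratch for every value of r.

I want to make my code faster.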

function [map] = map_at_R(sim_x,L_tr,L_te)

% sim_x(i,j) denotes the similarity between query j and database item i
tn = size(sim_x,2);
APx = zeros(tn,1);
R = 100;

for i = 1 : tn    
    Px = zeros(R,1);
    deltax = zeros(R,1);
    label = L_te(i,:);
    [~,inxx] = sort(sim_x(:,i),'descend');

    % compute Lx - the denominator in the MAP calculation
    % deltax(r) = 1 if the r-th retrieved item shares at least one label
    % with the query, else 0; Lx is the sum of deltax over r = 1..R
    search_set = L_tr(inxx(1:R),:);

    for r = 1 : R        
        %% FAST COMPUTATION
        Lrx = sum(diag(repmat(label,r,1)*search_set(1:r,:).')>0);

        %% SLOW COMPUTATION
%         Lrx = 0;
%         for j=1:r
%             if sum(label*(search_set(j,:)).')>0
%                 Lrx = Lrx+1;
%             end
%         end        

        if sum(label*(search_set(r,:)).')>0
            deltax(r) = 1;
        end

        Px(r) = Lrx/r;
    end
    Lx = sum(deltax);
    if Lx ~=0
        APx(i) = sum(Px.*deltax)/Lx;
    end
end
map = mean(APx);

The inputs to the code are:

% sim_x = similarity score matrix or distance matrix
sim_x = gallery_data_size X probe_data_size 

% L_tr = labels of the gallery set
L_tr = gallery_data_size X c

% L_te = labels of the probe set
L_te = probe_data_size X c

% where c is the number of classes
% please note that the data is multi-label
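For concreteness, a toy set of inputs with these shapes might be built as follows. This is only a sketch: the sizes, the 0.7 threshold and the names `n_gallery`, `n_probe` are made up for illustration, and the labels are assumed to be 0/1 indicator vectors. The gallery is kept larger than 100 items because the function hard-codes `R = 100`.

```matlab
% Hypothetical toy inputs matching the shapes described above.
rng(0);                                     % fix the seed for repeatability
n_gallery = 500;                            % illustrative gallery size (>= R = 100)
n_probe   = 20;                             % illustrative probe (query) set size
c         = 5;                              % number of classes

L_tr  = double(rand(n_gallery, c) > 0.7);   % gallery labels, 0/1, possibly several per row
L_te  = double(rand(n_probe,   c) > 0.7);   % probe labels, 0/1, possibly several per row
sim_x = rand(n_gallery, n_probe);           % similarity scores, gallery x probe
```

With these in the workspace, the function from the question would be called as `map_at_R(sim_x, L_tr, L_te)`.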

Is it possible to make the code faster? I can't figure it out myself.

1 answer:

Answer 0: (score: 2)

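Because of the delta function in APx(i) = sum(Px.*deltax)/Lx, the result of some of the r = 1:R iterations is thrown away. Since delta can be computed before the loop, why not iterate only over the values of r for which deltax(r) == 1?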

% r_range is equivalent to find(deltax(r) == 1);
%Edit 1/4 %Previously :: r_range = find(sum(label*(search_set(1:R,:)).')>0);
% Multiply each row by label
mult = bsxfun(@times,(search_set(1:R,:)),label);
% Sum each row 
r_range = find(sum(mult,2)>0);
% r_range @ i should equal find(deltax) @ i

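Px = zeros(numel(r_range),1);

% r_range is a column vector; a for loop iterates over columns,
% so transpose it to visit one value of r at a time
for r = r_range.'
    Lrx = sum(diag(repmat(label,r,1)*search_set(1:r,:).')>0);
    Px(r == r_range) = Lrx/r;
end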

Lx = numel(r_range);
if Lx ~=0
    APx(i) = sum(Px)/Lx;
end
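Going one step further than the answer, the inner loop can probably be dropped altogether: Lrx is the running count of matches, so it equals the cumulative sum of deltax. The following is a minimal sketch of that idea, not code from the original post. The function name `map_at_R_vec` and the extra argument `R` are hypothetical, and it assumes non-negative (for example 0/1) label matrices and a gallery with at least `R` items.

```matlab
function map = map_at_R_vec(sim_x, L_tr, L_te, R)
% Hypothetical vectorized variant of map_at_R (illustrative sketch).
% Assumes non-negative label matrices and size(sim_x,1) >= R.
tn  = size(sim_x, 2);            % number of queries
APx = zeros(tn, 1);

for i = 1:tn
    % rank the gallery items for query i, most similar first
    [~, inxx] = sort(sim_x(:, i), 'descend');

    % deltax(r) = 1 if the r-th retrieved item shares a label with query i
    overlap = double(L_tr(inxx(1:R), :)) * double(L_te(i, :)).';
    deltax  = double(overlap > 0);

    Lx = sum(deltax);
    if Lx ~= 0
        % cumsum(deltax) yields Lrx for every r in one pass
        Px     = cumsum(deltax) ./ (1:R).';
        APx(i) = sum(Px .* deltax) / Lx;
    end
end

map = mean(APx);
end
```

The cost per query then drops from building an r-by-r matrix product at every r to a single R-by-c matrix-vector product plus a cumulative sum. Passing `R = 100` reproduces the cutoff used in the question.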