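I am trying to write code that computes mean average precision (MAP) for multi-label data.
I have written the MAP computation in MATLAB, but it is slow. It is slow mainly because the variable `Lrx` is recomputed from scratch for every value of `r`.
I would like to make my code faster.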
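For reference, this is the quantity that `map_at_R` computes, written out from the MATLAB code. It is read off the code, not taken from a paper. The subscript x denotes a query, matching the code's `Px`, `Lx`, `Lrx`. $\mathbf{y}_x$ is the query's label row `L_te(x,:)`. $\mathbf{g}_r$ is the label row of the gallery item ranked r-th for that query, `search_set(r,:)`. N is the number of queries (`tn`).

$$
\delta_x(r) = \begin{cases} 1, & \mathbf{y}_x \mathbf{g}_r^{\top} > 0 \\ 0, & \text{otherwise} \end{cases}
\qquad
L_{rx} = \sum_{j=1}^{r} \delta_x(j)
\qquad
P_x(r) = \frac{L_{rx}}{r}
$$

$$
L_x = \sum_{r=1}^{R} \delta_x(r)
\qquad
\mathrm{AP}_x = \frac{1}{L_x} \sum_{r=1}^{R} P_x(r)\,\delta_x(r)
\qquad
\mathrm{MAP} = \frac{1}{N} \sum_{x=1}^{N} \mathrm{AP}_x
$$

$\mathrm{AP}_x$ is taken as 0 when $L_x = 0$. For 0/1 label rows, $\mathbf{y}_x \mathbf{g}_r^{\top} > 0$ means "shares at least one label".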
```matlab
function [map] = map_at_R(sim_x, L_tr, L_te)
% sim_x(i,j) denotes the similarity between query j and database item i
tn = size(sim_x, 2);          % number of queries
APx = zeros(tn, 1);
R = 100;                      % evaluate over the top-R retrieved items
for i = 1 : tn
    Px = zeros(R, 1);
    deltax = zeros(R, 1);
    label = L_te(i, :);
    [~, inxx] = sort(sim_x(:, i), 'descend');
    % deltax(r) = 1 if the item retrieved at rank r has the same label as
    % the query or shares at least one label with it, else deltax(r) = 0.
    % Lx = sum(deltax) is the denominator in the MAP calculation.
    search_set = L_tr(inxx(1:R), :);
    for r = 1 : R
        %% FAST COMPUTATION
        Lrx = sum(diag(repmat(label, r, 1) * search_set(1:r, :).') > 0);
        %% SLOW COMPUTATION
        % Lrx = 0;
        % for j = 1:r
        %     if sum(label * (search_set(j, :)).') > 0
        %         Lrx = Lrx + 1;
        %     end
        % end
        if sum(label * (search_set(r, :)).') > 0
            deltax(r) = 1;
        end
        Px(r) = Lrx / r;
    end
    Lx = sum(deltax);
    if Lx ~= 0
        APx(i) = sum(Px .* deltax) / Lx;
    end
end
map = mean(APx);
```
The inputs to the code are:

```matlab
% sim_x = similarity score matrix, size gallery_data_size x probe_data_size
%         (the code sorts each column in descending order, so a distance
%         matrix must be negated before it is passed in)
% L_tr  = labels of the gallery set, size gallery_data_size x c
% L_te  = labels of the probe set, size probe_data_size x c
% where c is the number of classes
% please note that the data is multi-label
```
Is it possible to make this code faster? I can't figure it out on my own.
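A quick check shows where the inner `for r = 1 : R` loop of `map_at_R` spends its time. Every row of `repmat(label,r,1)` is the same, so `diag(repmat(label,r,1)*search_set(1:r,:).')` equals the matrix-vector product `search_set(1:r,:)*label.'`. The "FAST" line therefore builds an r-by-r matrix only to keep its diagonal. Also, `Lrx` for all r at once is just the running count `cumsum` of the per-rank relevance flags. The sketch below verifies both claims on random 0/1 labels. The sizes and the names `c`, `S`, `rel`, and `rc` are made up for illustration.

```matlab
% Hypothetical sizes: R retrieved items, c classes, random 0/1 labels.
rng(0);                                   % fixed seed so the check is repeatable
R = 100; c = 10;
label = double(rand(1, c) > 0.7);         % query label row (1 x c)
S     = double(rand(R, c) > 0.7);         % stands in for search_set (R x c)

% Per-rank relevance: item r shares at least one label with the query.
rel = (S * label.') > 0;                  % R x 1 logical, same flags as deltax

% Lrx computed the question's way: diagonal of an r x r product, per r.
Lrx_loop = zeros(R, 1);
for r = 1:R
    Lrx_loop(r) = sum(diag(repmat(label, r, 1) * S(1:r, :).') > 0);
end

% The same values as a running count of relevant items, in one call.
Lrx_cum = cumsum(rel);

% The diagonal trick equals a plain matrix-vector product.
rc = 37;
d1 = diag(repmat(label, rc, 1) * S(1:rc, :).');
d2 = S(1:rc, :) * label.';
```

Both replacements do O(r·c) work instead of O(r²·c), and `cumsum` removes the need to recompute `Lrx` for each r.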
Answer 0 (score: 2)
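Because of the delta function in `APx(i) = sum(Px.*deltax)/Lx`, the results of every iteration of `r = 1:R` with `deltax(r) == 0` are thrown away. The deltas can be computed before the loop, so why not iterate only over the values of `r` for which `deltax(r) == 1`?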
```matlab
% Replaces the body of the question's `for i = 1 : tn` loop,
% after search_set has been computed.
% r_range is equivalent to find(deltax == 1).
% Multiply each row by label
mult = bsxfun(@times, search_set(1:R, :), label);
% Sum each row; a positive sum means the item shares at least one label
r_range = find(sum(mult, 2) > 0);
% r_range for query i should equal find(deltax) for query i
Px = zeros(numel(r_range), 1);
for k = 1:numel(r_range)   % find returns a column, so loop over its indices
    r = r_range(k);
    Lrx = sum(diag(repmat(label, r, 1) * search_set(1:r, :).') > 0);
    Px(k) = Lrx / r;
end
Lx = numel(r_range);
if Lx ~= 0
    APx(i) = sum(Px) / Lx;
end
```
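The loop above can be taken one step further, although this step is not part of the original answer. At the k-th relevant rank `r_range(k)`, exactly k relevant items have been seen, so `Lrx` is always k and never has to be recomputed. AP for a query then reduces to `mean((1:Lx).' ./ r_range)`. The sketch below applies this and also sorts every query column with a single `sort` call. It assumes 0/1 label matrices and a similarity (not distance) matrix. The toy data (a 4-item gallery, 2 queries, 3 classes, R = 3) is hypothetical and was chosen so that the result can be checked by hand.

```matlab
% Hypothetical toy inputs in the shapes described in the question:
% sim_x is gallery x probe, L_tr is gallery x c, L_te is probe x c.
sim_x = [0.9 0.1;
         0.8 0.7;
         0.3 0.9;
         0.1 0.2];
L_tr = [1 0 0;
        0 1 0;
        1 1 0;
        0 0 1];
L_te = [1 0 0;
        0 0 1];
R = 3;                                    % cut-off (100 in the question)

tn = size(sim_x, 2);                      % number of queries
R  = min(R, size(sim_x, 1));              % guard against a small gallery
[~, order] = sort(sim_x, 1, 'descend');   % rank every query column at once
APx = zeros(tn, 1);
for i = 1:tn
    top = L_tr(order(1:R, i), :);         % labels of the top-R items
    rel = (top * L_te(i, :).') > 0;       % deltax: shares >= 1 label
    r_range = find(rel);                  % ranks of the relevant items
    Lx = numel(r_range);
    if Lx > 0
        % At rank r_range(k) exactly k relevant items have been seen,
        % so Lrx = k and P(r_range(k)) = k / r_range(k).
        APx(i) = mean((1:Lx).' ./ r_range);
    end
end
map = mean(APx);
```

For this input, query 1 has relevant items at ranks 1 and 3, giving AP = (1/1 + 2/3)/2 = 5/6. Query 2 has its only relevant item at rank 3, giving AP = 1/3. So `map` is 7/12, the same value the question's `for r = 1 : R` loop produces on this data. Only the loop over queries remains, and each iteration costs O(R·c) after the sort.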