如果我有两个大小为[m x n]和[p x n]的矩阵A和B,我想找到每一行B出现在A中的次数,例如:
>> A = rand(5,3)
A =
0.1419 0.6557 0.7577
0.4218 0.0357 0.7431
0.9157 0.8491 0.3922
0.7922 0.9340 0.6555
0.9595 0.6787 0.1712
>> B = [A(2,:); A(1,:); A(2,:); A(3,:); A(3,:); A(4,:); A(5,:)]
B =
0.4218 0.0357 0.7431
0.1419 0.6557 0.7577
0.4218 0.0357 0.7431
0.9157 0.8491 0.3922
0.9157 0.8491 0.3922
0.7922 0.9340 0.6555
0.9595 0.6787 0.1712
在这种情况下的答案是
ans =
1 2 2 1 1
虽然与此示例不同,但通常m>> P
如果A和B是向量,matlab的histc就可以完成这项任务,但是如果这些是向量的话,似乎没有等效的。
目前我是这样做的:
for i=1:length(B)
indices(i) = find(abs(A/B(i,:)-1) < 1e-15);
% division requires a tolerance due to numerical issues
end
histc(indices, 1:size(A,1))
ans =
1 2 2 1 1
但由于我有很多这样的矩阵B,A和B都很大,所以速度非常慢。有任何想法如何改进吗?
编辑:
到目前为止看看这些方法,我有以下数据:
A 7871139x3 188907336 double
B 902x3 21648 double
为了让事情更快,我将使用前10行B
B = B(1:10,:);
请注意,对于完整的应用程序,我(目前)有> 10 ^ 4个这样的矩阵(这最终将> 10 ^ 6 ....)
我的第一种方法:
tic, C = get_vector_index(A,B); toc
Elapsed time is 36.630107 seconds.
bdecaf的方法(通过删除if
语句并使用L1距离而不是L2距离可以缩短到~25秒)
>> tic, C1 = get_vector_index(A,B); toc
Elapsed time is 28.957243 seconds.
>> isequal(C, C1)
ans =
1
oli的pdist2方法
>> tic, C2 = get_vector_index(A,B); toc
Elapsed time is 7.244965 seconds.
>> isequal(C,C2)
ans =
1
oli的规范化方法
>> tic, C3 = get_vector_index(A,B); toc
Elapsed time is 3.756682 seconds.
>> isequal(C,C3)
ans =
1
最后,我想出了另一种方法,在那里我搜索第一列,然后在第一列的点击内搜索第二列,递归直到列耗尽。这是目前为止最快的......
N = size(A,2);
loc = zeros(size(B,1),1);
for i=1:size(B,1)
idx{1} = find(A(:,1)==B(i,1));
for k=2:N,
idx{k} = idx{k-1}(find(A(idx{k-1},k)==B(i,k)));
end
loc(i) = idx{end};
end
C = histc(loc, 1:size(A,1));
导致:
>> tic, C4 = get_vector_index(A,B); toc
Elapsed time is 1.314370 seconds.
>> isequal(C, C4)
ans =
1
另请注意,使用intersect
要慢得多:
>> tic, [~,IA] = intersect(A,B,'rows'); C5 = histc(IA,1:size(A,1)); toc
Elapsed time is 44.392593 seconds.
>> isequal(C,C5)
ans =
1
答案 0 :(得分:1)
也许你可以将它们标准化,以便检查它们的点积是1
A = rand(5,3);
B = [A(2,:); A(1,:); A(2,:); A(3,:); A(3,:); A(4,:); A(5,:)];
A2=bsxfun(@times,A,1./sqrt(sum(A.^2,2))); %%% normalize A
B2=bsxfun(@times,B,1./sqrt(sum(B.^2,2))) %%% normalize B
sum(A2*B2'>1-10e-9,2) %%% check that the dotproduct is close to 1
ans =
1
2
2
1
1
如果您需要更快但近似的东西,我建议您使用flann库,这是快速近似的最近邻居:
答案 1 :(得分:1)
我会这样解决:
indices = zeros(size(A,1),1);
for i=1:size(B,1)
distances = sum( ( repmat(B(i,:),size(A,1),1)-A ).^2 ,2);
[md,im]=min(distances);
if md < 1e-9
indices(im) = indices(im)+1;
end
end
如果你删除它,它只会排序到最近的bin。
答案 2 :(得分:1)
实际上,更简单的方法是:
sum(10e-9>pdist2(A',B'),2)
它计算所有成对距离和阈值和计数。
答案 3 :(得分:0)
我实际上把这个解决方案作为问题的编辑,但为了接受答案,我也将解决方案放在这里:
N = size(A,2);
loc = zeros(size(B,1),1);
for i=1:size(B,1)
idx{1} = find(A(:,1)==B(i,1));
for k=2:N,
idx{k} = idx{k-1}(find(A(idx{k-1},k)==B(i,k)));
end
loc(i) = idx{end};
end
C = histc(loc, 1:size(A,1));
导致:
>> tic, C4 = get_vector_index(A,B); toc
Elapsed time is 1.314370 seconds.
>> isequal(C, C4)
ans =
1