在matlab矩阵中查找子矩阵的一般方法

时间:2013-09-16 14:51:13

标签: matlab matrix find

我正在寻找一种'好'的方法来在更大的矩阵(任意数量的维度)中找到矩阵(模式)。

示例:

total = rand(3,4,5);
sub = total(2:3,1:3,3:4);

现在我希望这种情况发生:

loc = matrixFind(total, sub)

在这种情况下,loc应该成为[2 1 3]

现在我只想找到一个单点(如果它存在)并且不担心舍入问题。可以假设sub'适合'total


以下是我如何为3个维度做到这一点,但它只是感觉有更好的方法:

total = rand(3,4,5);
sub = total(2:3,1:3,3:4);
loc = [];
for x = 1:size(total,1)-size(sub,1)+1
    for y = 1:size(total,2)-size(sub,2)+1
        for z = 1:size(total,3)-size(sub,3)+1
            block = total(x:x+size(sub,1)-1,y:y+size(sub,2)-1,z:z+size(sub,3)-1);
            if isequal(sub,block)
                loc = [x y z]
            end
        end
    end
end

我希望为任意数量的维度找到可行的解决方案。

3 个答案:

答案 0 :(得分:3)

这是低性能,但(据称)任意维度函数。它使用find创建total中潜在匹配位置的(线性)索引列表,然后检查total的适当大小的子块是否与sub匹配。

function loc = matrixFind(total, sub)
%matrixFind find position of array in another array

    % initialize result
    loc = [];

    % pre-check: do all elements of sub exist in total?
    elements_in_both = intersect(sub(:), total(:));
    if numel(elements_in_both) < numel(unique(sub))
        % if not, return nothing
        return
    end

    % select a pivot element
    % Improvement: use least common element in total for less iterations
    pivot_element = sub(1);

    % determine linear index of all occurences of pivot_elemnent in total
    starting_positions = find(total == pivot_element);

    % prepare cell arrays for variable length subscript vectors
    [subscripts, subscript_ranges] = deal(cell([1, ndims(total)]));


    for k = 1:length(starting_positions)
        % fill subscript vector for starting position
        [subscripts{:}] = ind2sub(size(total), starting_positions(k));

        % add offsets according to size of sub per dimension
        for m = 1:length(subscripts)
            subscript_ranges{m} = subscripts{m}:subscripts{m} + size(sub, m) - 1;
        end

        % is subblock of total equal to sub
        if isequal(total(subscript_ranges{:}), sub)
            loc = [loc; cell2mat(subscripts)]; %#ok<AGROW>
        end
    end
end

答案 1 :(得分:2)

这是基于对原始矩阵total进行所有可能的移位,并将移位的total的最左上等子矩阵与所寻找的模式subs进行比较。使用字符串生成移位,并使用circshift来应用。

大部分工作都是矢量化的。只使用一级循环。

该函数查找所有匹配项,而不仅仅是第一项。例如:

>> total = ones(3,4,5,6);
>> sub = ones(3,3,5,6);
>> matrixFind(total, sub)
ans =

     1     1     1     1
     1     2     1     1

这是功能:

function sol = matrixFind(total, sub)

nd = ndims(total);
sizt = size(total).';
max_sizt = max(sizt);
sizs = [ size(sub) ones(1,nd-ndims(sub)) ].'; % in case there are
% trailing singletons

if any(sizs>sizt)
    error('Incorrect dimensions')
end

allowed_shift = (sizt-sizs);
max_allowed_shift = max(allowed_shift);
if max_allowed_shift>0
    shifts = dec2base(0:(max_allowed_shift+1)^nd-1,max_allowed_shift+1).'-'0';
    filter = all(bsxfun(@le,shifts,allowed_shift));
    shifts = shifts(:,filter); % possible shifts of matrix "total", along 
    % all dimensions
else
    shifts = zeros(nd,1);
end

for dim = 1:nd
    d{dim} = 1:sizt(dim); % vectors with subindices per dimension
end
g = cell(1,nd);
[g{:}] = ndgrid(d{:}); % grid of subindices per dimension
gc = cat(nd+1,g{:}); % concatenated grid
accept = repmat(permute(sizs,[2:nd+1 1]), [sizt; 1]); % acceptable values
% of subindices in order to compare with matrix "sub"
ind_filter = find(all(gc<=accept,nd+1));

sol = [];
for shift = shifts
    total_shifted = circshift(total,-shift);
    if all(total_shifted(ind_filter)==sub(:))
        sol = [ sol; shift.'+1 ];
    end
end

答案 2 :(得分:1)

对于任意数量的维度,您可以尝试convn

C = convn(total,reshape(sub(end:-1:1),size(sub)),'valid'); % flip dimensions of sub to be correlation
[~,indmax] = max(C(:));
% thanks to Eitan T for the next line
cc = cell(1,ndims(total)); [cc{:}] = ind2sub(size(C),indmax); subs = [cc{:}]

感谢Eitan T建议对广义的ind2sub使用逗号分隔列表。

最后,您应该使用isequal测试结果,因为这不是标准化的互相关,这意味着本地子区域中的较大数字会使相关值膨胀,从而可能产生误报。如果您的total矩阵与大值区域非常不同,则可能需要在C中搜索其他最大值。