我有两个ND数组(10 * 4 * 4 * 2 * 4),我在第一个中查找每列的最大值,并希望从第二个中提取相同位置的元素。
为简单起见,我们假设我有A和B如下:
A = randi(100,4,3,2);
B = randi(10,size(A));
现在,我用:
查找最大元素的索引[~,ind] = max(A);
然后我想使用ind
从B
中提取元素。最佳选项类似C = B(ind)
,但显然不起作用。 Matlab将ind
称为线性索引,就像我写C = B(ind(:))
一样,除了C
的大小变得像ind
(即1 * 3 * 2)。
所以我也试过这样的东西(我发现here):
m = {ind,':',':'};
C = B(m{:})
但是这会产生一个大小为6 * 3 * 2的数组,其中ind
是B
中每列的线性索引。当然我可以使用循环来做到这一点,但我确信有一种更有效和更优雅的方式。
示例:
A(:,:,1) =
40 89 30
73 77 30
59 61 14
29 2 30
A(:,:,2) =
82 79 5
3 40 62
46 76 42
22 52 74
B(:,:,1) =
5 1 9
3 3 5
8 4 4
8 3 9
B(:,:,2) =
1 4 3
5 4 8
10 8 5
9 1 3
ind(:,:,1) =
2 1 1
ind(:,:,2) =
1 1 4
所以结果应该是:
C =
3 1
1 4
9 3
请注意,我寻找一般解决方案,而不是仅适用于此示例中的维度的解决方案。
答案 0 :(得分:1)
对于三维,您可以使用you do it in config.yml将(子)索引转换为线性索引:
[~, ind] = max(A,[],1);
linind = sub2ind(size(A), reshape(ind, size(A,2), size(A,3)), ...
repmat((1:size(A,2)).', 1, size(A,3)), ...
repmat(1:size(A,3), size(A,2), 1));
C = B(linind);
对于任意数量的维度,在调用max
之前,可以更容易地将第一个维度之外的所有维度折叠为一个维度。这也可以通过更快速的手动计算更轻松地替换sub2ind
:
Ar = reshape(A, size(A,1), []); % collapse all dimensions beyond the first
[~, ind] = max(Ar,[],1); % arg max of each column
linind = ind + (0:size(Ar,2)-1)*size(Ar,1); % convert to linear indices
C = B(linind); % index into C
sz = size(A); % size of A
C = reshape(C, sz(2:end)); % reshape C according to shape of A