我有一个X
数据集m x 2
,以及存储在C = [c1'; c2'; c3']
矩阵3 x 2
中的三个向量。我正在尝试对我的代码进行矢量化,为X
中的每个数据点找到C
中哪个向量最接近(平方距离)。我想从C
中的每个向量(行)中减去X
中的每个向量(行),从而导致m x 6
或3m x 2
矩阵之间的差异。 X
以及C
的元素。我当前的实现一次在X
中执行这一行:
for i = 1:size(X, 1)
diffs = bsxfun(@minus, X(i,:), C); % gives a 3 x 2 matrix result
[~, idx(i)] = min(sumsq(diffs), 2); % returns the index of the closest vector
% in C to the ith vector in X
end
我想摆脱这个for
循环,只是对整个事物进行矢量化,但是bsxfun(@minus, X, C)
在Octave中给了我一个错误:
错误:bsxfun:不一致的尺寸:300x2和3x2
我有什么想法可以“超广播”这两个矩阵之间的减法运算?
答案 0 :(得分:5)
此问题的核心是计算大小为D
的距离矩阵m x 3
,其中包含X
中所有数据点与{{1}中所有数据点之间的成对距离}。 C
中第i个向量x_i
与X
中第j个向量c_j
之间的欧几里德距离可以重写为:
C
其中&lt;,&gt;指内在产品。这个等式的右边可以很容易地进行矢量化,因为所有对的内积只是|x_i-c_j|^2 = |x_i|^2 - 2<x_i, c_j> + |c_j|^2
,这是BLAS3操作。这种计算距离矩阵的方法在Christopher Bishop的书籍模式识别和机器学习中被称为X * C'
函数。我稍微修改了下面的功能。
dist2
function D = dist2(X, C)
tempx = full(sum(X.^2, 2));
tempc = full(sum(C.^2, 2).');
D = -2*(X * C.');
D = bsxfun(@plus, D, tempx);
D = bsxfun(@plus, D, tempc);
此处用于full
或X
为稀疏矩阵的情况。
注意:由于数字舍入误差,以这种方式计算的距离矩阵C
可能会有微小的负数条目。为了防范这种情况,请使用
D
D = max(D, 0);
中最接近的向量的索引可以从C
:
D
答案 1 :(得分:0)
如果您有统计工具箱,则可以使用pdist2
:
PDIST2两组观测值之间的成对距离。 D = PDIST2(X,Y)返回包含欧几里德距离的矩阵D. 在MX-by-N数据矩阵X和中的每对观察之间 MY-by-N数据矩阵Y.
所以在你的情况下,
[~, which_C] = min(pdist2(X,C), [], 2);
是您正在寻找的。 p>
或者,你可以使用这种美:
[~, which_c] = min(sum(bsxfun(@minus, X, permute(C, [3 2 1])).^2, 2), [], 3);
在可读性,稳健性或可管理性方面不会赢得任何奖项,但你会获得一些速度(并且需要一个工具箱,请注意:)