从数组的每一行中减去多个向量(超广播)

时间:2013-06-18 20:40:35

标签: matlab octave vectorization bsxfun

我有一个X数据集m x 2,以及存储在C = [c1'; c2'; c3']矩阵3 x 2中的三个向量。我正在尝试对我的代码进行矢量化,为X中的每个数据点找到C中哪个向量最接近(平方距离)。我想从C中的每个向量(行)中减去X中的每个向量(行),从而导致m x 63m x 2矩阵之间的差异。 X以及C的元素。我当前的实现一次在X中执行这一行:

for i = 1:size(X, 1)
    diffs = bsxfun(@minus, X(i,:), C);    % gives a 3 x 2 matrix result
    [~, idx(i)] = min(sumsq(diffs), 2);   % returns the index of the closest vector
                                          % in C to the ith vector in X
end

我想摆脱这个for循环,只是对整个事物进行矢量化,但是bsxfun(@minus, X, C)在Octave中给了我一个错误:

  

错误:bsxfun:不一致的尺寸:300x2和3x2

我有什么想法可以“超广播”这两个矩阵之间的减法运算?

2 个答案:

答案 0 :(得分:5)

此问题的核心是计算大小为D的距离矩阵m x 3,其中包含X中所有数据点与{{1}中所有数据点之间的成对距离}。 C中第i个向量x_iX中第j个向量c_j之间的欧几里德距离可以重写为:

C

其中&lt;,&gt;指内在产品。这个等式的右边可以很容易地进行矢量化,因为所有对的内积只是|x_i-c_j|^2 = |x_i|^2 - 2<x_i, c_j> + |c_j|^2 ,这是BLAS3操作。这种计算距离矩阵的方法在Christopher Bishop的书籍模式识别和机器学习中被称为X * C'函数。我稍微修改了下面的功能。

dist2

function D = dist2(X, C) tempx = full(sum(X.^2, 2)); tempc = full(sum(C.^2, 2).'); D = -2*(X * C.'); D = bsxfun(@plus, D, tempx); D = bsxfun(@plus, D, tempc); 此处用于fullX为稀疏矩阵的情况。

注意:由于数字舍入误差,以这种方式计算的距离矩阵C可能会有微小的负数条目。为了防范这种情况,请使用

D

D = max(D, 0); 中最接近的向量的索引可以从C

中检索
D

答案 1 :(得分:0)

如果您有统计工具箱,则可以使用pdist2

  

PDIST2两组观测值之间的成对距离。       D = PDIST2(X,Y)返回包含欧几里德距离的矩阵D.       在MX-by-N数据矩阵X和中的每对观察之间       MY-by-N数据矩阵Y.

所以在你的情况下,

[~, which_C] = min(pdist2(X,C), [], 2);

是您正在寻找的。

或者,你可以使用这种美:

[~, which_c] = min(sum(bsxfun(@minus, X, permute(C, [3 2 1])).^2, 2), [], 3);

在可读性,稳健性或可管理性方面不会赢得任何奖项,但你会获得一些速度(并且需要一个工具箱,请注意:)