是否有可能优化这个Matlab代码,用k-means进行质心矢量量化?

时间:2011-04-20 21:06:10

标签: optimization matlab vector k-means quantization

我使用大小为4000x300的k-means创建了一个码本(4000个质心,每个都有300个功能)。然后,我使用代码簿标记输入向量(以便稍后进行分级)。输入向量的大小为Nx300,其中N是我收到的输入实例的总数。

为了计算标签,我计算每个输入向量的最近质心。为此,我将每个输入矢量与所有质心进行比较,并选择具有最小距离的质心。然后标签就是该质心的索引。

我目前的Matlab代码如下:

function labels = assign_labels(centroids, X)
labels = zeros(size(X, 1), 1);

% for each X, calculate the distance from each centroid
for i = 1:size(X, 1)
    % distance of X_i from all j centroids is: sum((X_i - centroid_j)^2)
    % note: we leave off the sqrt as an optimization
    distances = sum(bsxfun(@minus, centroids, X(i, :)) .^ 2, 2);
    [value, label] = min(distances);
    labels(i) = label;
end     

但是,这段代码仍然相当慢(为了我的目的),我希望有可能进一步优化代码。

一个明显的问题是有一个for循环,这是Matlab上性能良好的祸根。我一直试图找到摆脱它的方法,但没有运气(我调查使用arrayfun与bsxfun,但没有得到这个工作)。或者,如果有人知道任何其他方法来加快这一点,我将非常感激。

更新

在做了一些搜索之后,我找不到使用Matlab的很好的解决方案,所以我决定查看Python的scikits.learn包中用于'euclidean_distance'(缩短)的内容:

 XX = sum(X * X, axis=1)[:, newaxis]
 YY = Y.copy()
 YY **= 2
 YY = sum(YY, axis=1)[newaxis, :]
 distances = XX + YY
 distances -= 2 * dot(X, Y.T)
 distances = maximum(distances, 0)

使用欧几里德距离的二项式形式((x-y)^ 2 - > x ^ 2 + y ^ 2 - 2xy),从我读过的内容通常运行得更快。我完全未经测试的Matlab翻译是:

 XX = sum(data .* data, 2);
 YY = sum(center .^ 2, 2);
 [val, ~] = max(XX + YY - 2*data*center');

4 个答案:

答案 0 :(得分:4)

使用以下功能计算距离。你应该看到一个数量级的加速

两个矩阵A和B的列为Dimenions,行为每个点。 A是你的质心矩阵。 B是您的数据点矩阵。

function D=getSim(A,B)
    Qa=repmat(dot(A,A,2),1,size(B,1));
    Qb=repmat(dot(B,B,2),1,size(A,1));
    D=Qa+Qb'-2*A*B';

答案 1 :(得分:1)

您可以通过转换为单元格并使用cellfun来对其进行矢量化:

[nRows,nCols]=size(X);
XCell=num2cell(X,2);
dist=reshape(cell2mat(cellfun(@(x)(sum(bsxfun(@minus,centroids,x).^2,2)),XCell,'UniformOutput',false)),nRows,nRows);
[~,labels]=min(dist);

<强>解释

  • 我们将X的每一行分配到第二行中的自己的单元格
  • 此作品@(x)(sum(bsxfun(@minus,centroids,x).^2,2))是一个匿名函数,与您的distances=...行相同,使用cell2mat,我们将其应用于X的每一行。
  • 然后标签是每列最小行的索引。

答案 2 :(得分:1)

对于真正的矩阵实现,您可以考虑尝试以下方式:

  P2 = kron(centroids, ones(size(X,1),1));
  Q2 = kron(ones(size(centroids,1),1), X);

  distances = reshape(sum((Q2-P2).^2,2), size(X,1), size(centroids,1));

注意 假设数据组织为[x1 y1 ...; x2 y2 ...; ...]

答案 3 :(得分:1)

您可以使用更有效的算法进行最近邻搜索而不是强力搜索。 最流行的方法是Kd-Tree。 O(log(n))平均查询时间而不是O(n)暴力复杂度。 关于马耳他实施的Kd-Trees,您可以查看here