我为冗长的标题道歉。
请考虑以下代码,这些代码可能会出现在k-means
群集中。
rng(1)
num_samples = 10;
samples = randi(100, num_samples, 3);
cluster_centroids = randi(16,3);
cluster_indices = zeros(num_samples,1);
for index = 1:num_samples
distances = sqrt(sum((samples(index) - cluster_centroids).^2, 2));
cluster = find(distances == min(distances), 1)
cluster_indices(index) = cluster;
end
有没有办法对它进行矢量化并删除for循环,这样我们才能有效地处理所有样本(三个整数的元组)?
答案 0 :(得分:2)
现有代码有几个问题,我在底部列出,但首先是矢量化答案。答案的核心类似于my past solution。我已经稍微修改了一下你的变量,并添加了一些解释:
[nPoints,nDims] = size(samples);
k = 3; % ? size(cluster_centroids,1)
% Calculate all high-dimensional distances at once
% (NxDx1 - 1xDxK => NxDxK)
kdiffs = bsxfun(@minus,samples,permute(cluster_centroids,[3 2 1]));
distances = sum(kdiffs.^2,2); % no need to do sqrt
distances = squeeze(distances); % Nx1xK => NxK
distances
是向量化的重要值。其余的相当简单:
% Find closest cluster center for each point
[~,cluster_indices] = min(distances,[],2); % Nx1
然后,您需要更新集群中心以进行以下迭代:
cluster_centroids_new = zeros(k,nDims);
for i=1:k,
indk = cluster_indices==i;
clustersizes(i) = nnz(indk);
cluster_centroids_new(i,:) = mean(samples(indk,:))';
end
要为此代码添加一些讨论,请注意cluster_centroids_new
每个群集都有一行,但如果群集没有成员,则该行将为NaN
s。
有问题的代码问题:
cluster_indices
被初始化为矩阵而不是向量。修复:cluster_indices = zeros(num_samples,1);
(如horchler所述)。samples
计算中正确编制distances
索引。将samples(index)
更改为samples(index,:)
以提取样本。cluster_centroids
之间的距离需要为sum((repmat(samples(index,:),size(cluster_centroids,1),1) - cluster_centroids).^2, 2)
,才能计算样本与所有k=size(cluster_centroids,1)
群集的距离。cluster_centroids = randi(16,3);
,它为3x3矩阵提供1到16之间的数字。