Visualizing the movement of data points while training a self-organizing map (SOM) using Simulink

Date: 2013-09-11 20:58:14

Tags: matlab machine-learning neural-network self-organizing-maps

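I have implemented the self-organizing map (SOM) algorithm in MATLAB. Suppose each data point is represented in 2-dimensional space. The problem is that I want to visualize the movement of each data point during the training phase, i.e. I want to see how the points move and eventually form clusters as the algorithm progresses, say at every fixed interval. I believe this can be done through simulation in MATLAB, but I don't know how to incorporate my MATLAB code for the visualization.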

1 answer:

Answer 0 (score: 2)

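I developed a code example that visualizes clustered data with multiple dimensions using all possible 2-D projections of the data. It may not be the best idea for visualization (there are techniques developed for this, and SOM itself may be used for this purpose), especially for a higher number of dimensions, but when the number of possible projections, n(n-1)/2 for n dimensions, is not that high, it is quite a good visualizer.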
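
To make the projection count concrete, here is a minimal sketch that lists every pair of dimensions that would get its own 2-D panel. The variable names are illustrative and are not part of the code further down; nDim = 5 is assumed only because the example data has 5 dimensions.

```matlab
% Enumerate every 2-D projection (pair of dimensions) of n-dimensional data.
nDim = 5;                         % assumed dimensionality, as in the example data
pairs = nchoosek(1:nDim, 2);      % each row is one [xDim yDim] pair
nProjections = size(pairs, 1);    % equals nDim*(nDim-1)/2
fprintf('%d dimensions give %d projections\n', nDim, nProjections);
```

The count grows quadratically, which is why the grid of small panels becomes hard to read for high-dimensional data.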


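Clustering Algorithm

Since I needed access to the code so that I could save the cluster means and cluster labels for each iteration, I used a fast kmeans algorithm available at FEX by Mo Chen, but I had to adapt it to get that access. The adapted code is the following: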

function [label,m] = litekmeans(X, k)
% Perform k-means clustering.
%   X: d x n data matrix
%   k: number of seeds
% Written by Michael Chen (sth4nth@gmail.com).
n = size(X,2);
last = 0;
iter = 1;
label{iter} = ceil(k*rand(1,n));  % random initialization
checkLabel = label{iter};
m = {};
while any(checkLabel ~= last)
    [u,~,checkLabel] = unique(checkLabel);   % remove empty clusters
    k = length(u);
    E = sparse(1:n,checkLabel,1,n,k,n);  % transform label into indicator matrix
    curM = X*(E*spdiags(1./sum(E,1)',0,k,k));    % compute m of each cluster
    m{iter} = curM;
    last = checkLabel';
    [~,checkLabel] = max(bsxfun(@minus,curM'*X,dot(curM,curM,1)'/2),[],1); % assign samples to the nearest centers
    iter = iter + 1;
    label{iter} = checkLabel;
end
% Get last clusters centers
m{iter} = curM;
% If to remove empty clusters:
%for k=1:iter
%  [~,~,label{k}] = unique(label{k});
%end
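
The two vectorized lines in the loop above are dense, so here is a small worked sketch of what they compute, using a hypothetical 2-D data set with a fixed labeling (the data and labels are made up for illustration). The first step averages the samples of each cluster through a sparse indicator matrix; the second assigns each sample to the nearest center, using the fact that minimizing ||x - m||^2 is the same as maximizing m'*x - (m'*m)/2.

```matlab
% Worked example of the mean-update and assignment steps of litekmeans.
X = [0 2 10 12;                   % hypothetical data, one sample per column
     0 0  4  4];
label = [1 1 2 2];                % hypothetical current labels
n = size(X, 2);
k = max(label);

% Indicator matrix: E(i,j) = 1 when sample i belongs to cluster j.
E = sparse(1:n, label, 1, n, k, n);
% Scale each column by 1/(cluster size), so X*(...) averages per cluster.
m = X*(E*spdiags(1./sum(E,1)', 0, k, k));

% Nearest-center assignment: argmax over clusters of m'*x - ||m||^2/2.
[~, newLabel] = max(bsxfun(@minus, m'*X, dot(m,m,1)'/2), [], 1);

disp(full(m));                    % cluster means, one column per cluster
disp(newLabel);                   % labels after reassignment
```

Here the means come out as [1 11; 0 4] and the reassignment leaves the labels unchanged, so this labeling would already be a fixed point of the loop.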

Gif Creation

I also used @Amro's Matlab video tutorial for the gif creation.
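
For readers who only need the gif part, the following is a minimal sketch of that pattern, assuming a figure window can be rendered. The moving point, the frame count and the output file name are hypothetical; only getframe, rgb2ind and imwrite are the calls the larger code below relies on.

```matlab
% Minimal animated-GIF sketch: one frame per "iteration" of a moving point.
figH = figure('Color', 'w');
axH = axes('Parent', figH, 'XLim', [0 1], 'YLim', [0 1], 'NextPlot', 'add');
pointH = plot(axH, 0, 0.5, 'o');
nFrames = 5;                                  % hypothetical number of iterations

% Build a shared colormap from the first frame, then preallocate all frames.
f = getframe(figH);
[firstFrame, map] = rgb2ind(f.cdata, 256, 'nodither');
mov = repmat(firstFrame, [1 1 1 nFrames]);

for k = 1:nFrames
  set(pointH, 'XData', (k-1)/(nFrames-1));    % move the point to the right
  drawnow;
  f = getframe(figH);
  mov(:,:,1,k) = rgb2ind(f.cdata, map, 'nodither');
end

gifFile = fullfile(tempdir, 'demo_animation.gif');   % hypothetical output name
imwrite(mov, map, gifFile, 'DelayTime', 0.5, 'LoopCount', inf);
```

Reusing one colormap for every frame keeps the colors stable across the animation, at the cost of assuming that the first frame already contains all colors that matter.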

Distinguishable colors

I used this great FEX submission by Tim Holy to make the cluster colors easier to tell apart.

Resulting code

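My resulting code is as follows. I had some issues because the number of clusters can change from one iteration to the next, which caused the scatter plot update to delete all cluster centers without raising any error. Since I did not notice that at first, I tried to work around the scatter function with every obscure method I could find on the web (by the way, I found a really nice scatter plot alternative here), but fortunately I figured out what was happening when I came back to this today. Here is the code I wrote for it; feel free to use it and adapt it, but please keep my reference if you use it.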
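
The pitfall described above can be sketched in isolation. This is a hedged illustration, not the code used below: it updates a scatter object directly with set instead of refreshdata, uses hsv(7) as a stand-in for distinguishable_colors(7), and uses random hypothetical centers. The point is only that XData, YData and the rows of CData must agree in length after the update.

```matlab
% Sketch: update a scatter object when the number of cluster centers changes.
figH = figure('Visible', 'off');
axH = axes('Parent', figH);
colors = hsv(7);                               % stand-in for distinguishable_colors(7)
centers = rand(2, 7);                          % hypothetical initial cluster centers
scH = scatter(axH, centers(1,:), centers(2,:), 100, colors, '^', 'filled');

% Suppose one cluster became empty, leaving 6 centers:
newCenters = rand(2, 6);
newColors = colors(1:size(newCenters, 2), :);  % one CData row per remaining center
assert(size(newColors, 1) == size(newCenters, 2));

% Set all three properties in one call so their sizes stay consistent.
set(scH, 'XData', newCenters(1,:), 'YData', newCenters(2,:), 'CData', newColors);
```

The explicit assert is cheap and turns a silent rendering failure into an immediate error.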

function varargout=kmeans_test(data,nClusters,plotOpts,dimLabels,...
  bigXDim,bigYDim,gifName)
%
% [label,m,figH,handles]=kmeans_test(data,nClusters,plotOpts,...
%   dimLabels,bigXDim,bigYDim,gifName)
% Demonstrate kmeans algorithm iterative progress. Inputs are:
%
% -> data (rand(5,100)): the data to use.
%
% -> nClusters (7): number of clusters to use.
%
% -> plotOpts: struct holding the following fields:
%
%   o leftBase: the percentage distance from the left
%
%   o rightBase: the percentage distance from the right
%
%   o bottomBase: the percentage distance from the bottom
%
%   o topBase: the percentage distance from the top
%
%   o FontSize: FontSize for axes labels.
%
%   o widthUsableArea: Total width occupied by axes
%
%   o heigthUsableArea: Total height occupied by axes
%
% -> bigXDim (1): the big subplot x dimension
%
% -> bigYDim (2): the big subplot y dimension
%
% -> dimLabels: If you want to specify dimensions labels
%
% -> gifName: gif file name to save
%
% Outputs are:
% 
% -> label: Sample cluster center number for each iteration
%
% -> m: cluster center mean for each iteration
%
% -> figH: figure handle
%
% -> handles: axes handles
%

%
% - Creation Date: Fri, 13 Sep 2013 
% - Last Modified: Mon, 16 Sep 2013 
% - Author(s): 
%   - W.S.Freund <wsfreund_at_gmail_dot_com> 

%
% TODO List (?):
%
%  - Use input parser 
%  - Adapt it to be able to cluster any algorithm function.
%  - Use arrows indicating cluster centers movement before moving them.
%  - Drag and drop small axes to big axes.
%

% Pre-start
if nargin < 7
  gifName = 'kmeansClusterization.gif';
  if nargin < 6
    bigYDim = 2;
    if nargin < 5
      bigXDim = 1;
      if nargin < 4
        nDim = size(data,1);
        maxDigits = numel(num2str(nDim));
        dimLabels = mat2cell(sprintf(['Dim %0' num2str(maxDigits) 'd'],...
          1:nDim),1,zeros(1,nDim)+4+maxDigits);
        if nargin < 3
          plotOpts = struct('leftBase',.05,'rightBase',.02,...
            'bottomBase',.05,'topBase',.02,'FontSize',10,...
            'widthUsableArea',.87,'heigthUsableArea',.87);
          if nargin < 2
            nClusters = 7;
            if nargin < 1
              center1 = [1; 0; 0; 0; 0];
              center2 = [0; 1; 0; 0; 0];
              center3 = [0; 0; 1; 0; 0];
              center4 = [0; 0; 0; 1; 0];
              center5 = [0; 0; 0; 0; 1];
              center6 = [0; 0; 0; 0; 1.5];
              center7 = [0; 0; 0; 1.5; 1];
              data = [...
                      bsxfun(@plus,center1,.5*rand(5,20)) ...
                      bsxfun(@plus,center2,.5*rand(5,20)) ...
                      bsxfun(@plus,center3,.5*rand(5,20)) ...
                      bsxfun(@plus,center4,.5*rand(5,20)) ...
                      bsxfun(@plus,center5,.5*rand(5,20)) ...
                      bsxfun(@plus,center6,.2*rand(5,20)) ...
                      bsxfun(@plus,center7,.2*rand(5,20)) ...
                     ];
            end
          end
        end
      end
    end
  end
end

% NOTE: refreshdata does not appear to check that the XData, YData and
% CData sources of a scatter object have matching sizes. When they do
% not match, the cluster centers silently disappear, which cost me a
% lot of debugging time and several unnecessary workarounds.

% Draw axes:
nDim = size(data,1);

figH = figure;
set(figH,'Units', 'normalized', 'Position',...
  [0, 0, 1, 1],'Color','w','Name',...
  'k-means example','NumberTitle','Off',...
  'MenuBar','none','Toolbar','figure',...
  'Renderer','zbuffer');

% Create distinguishable colors matrix:
colorMatrix = distinguishable_colors(nClusters);

% Create axes, deploy them on handles matrix more or less how they
% will be positioned:
[handles,horSpace,vertSpace] = ...
  createAxesGrid(nDim,nDim,plotOpts,dimLabels);

% Add main axes
bigSubSize = ceil(nDim/2);
bigSubVec(bigSubSize^2) = 0;
for k = 0:nDim-bigSubSize
  bigSubVec(k*bigSubSize+1:(k+1)*bigSubSize) = ...
    ... %(nDim-bigSubSize+k)*nDim+1:(nDim-bigSubSize+k)*nDim+(nDim-bigSubSize+1);
    bigSubSize+nDim*k:nDim*(k+1);
end

handles(bigSubSize,bigSubSize) = subplot(nDim,nDim,bigSubVec,...
  'FontSize',plotOpts.FontSize,'box','on'); 
bigSubplotH = handles(bigSubSize,bigSubSize);
horSpace(bigSubSize,bigSubSize) = bigSubSize;
vertSpace(bigSubSize,bigSubSize) = bigSubSize;
set(bigSubplotH,'NextPlot','add',...
  'FontSize',plotOpts.FontSize,'box','on',...
  'XAxisLocation','top','YAxisLocation','right');

% Squeeze axes through space to optimize space usage and improve
% visualization capability:
[leftPos,botPos,subplotWidth,subplotHeight]=setCustomPlotArea(...
  handles,plotOpts,horSpace,vertSpace);

pColorAxes = axes('Position',[leftPos(end) botPos(end) ...
  subplotWidth subplotHeight],'Parent',figH);
pcolor([1:nClusters+1;1:nClusters+1])
% image(reshape(colorMatrix,[1 size(colorMatrix)])); % Used image to
% check if the upcoming buggy behaviour would be fixed. I was not
% lucky, though...
colormap(pColorAxes,colorMatrix);
% Put the XTick positions at the center of each color cell (pcolor
% draws cell k between x = k and x = k+1):
set(pColorAxes,'XTick',1.5:1:nClusters+.5);
set(pColorAxes,'YTick',[]);
% Label each cell with its cluster number:
set(pColorAxes,'XTickLabel',1:nClusters);
xlabel(pColorAxes,'Clusters Colors','FontSize',plotOpts.FontSize);

% Now iterate through data and get cluster information:
[label,m]=litekmeans(data,nClusters);

nIters = numel(m)-1;

scatterColors = colorMatrix(label{1},:);

annH = annotation('textbox',[leftPos(1),botPos(1) subplotWidth ...
  subplotHeight],'String',sprintf('Start Conditions'),'EdgeColor',...
  'none','FontSize',18);

% Creates dimData_%d variables for first iteration:
for curDim=1:nDim
  curDimVarName = genvarname(sprintf('dimData_%d',curDim));
  eval([curDimVarName,'= m{1}(curDim,:);']);
end

%   clusterColors will hold the colors for the total number of clusters
% on each iteration:
clusterColors = colorMatrix;

% Draw cluster information for first iteration:
for curColumn=1:nDim
  for curLine=curColumn+1:nDim
    % Big subplot data:
    if curColumn == bigXDim && curLine == bigYDim
      curAxes = handles(bigSubSize,bigSubSize);
      curScatter = scatter(curAxes,data(curColumn,:),...
        data(curLine,:),16,'filled');
      set(curScatter,'CDataSource','scatterColors');
      % Draw cluster centers 
      curColumnVarName = genvarname(sprintf('dimData_%d',curColumn));
      curLineVarName = genvarname(sprintf('dimData_%d',curLine));
      eval(['curScatter=scatter(curAxes,' curColumnVarName ',' ... 
        curLineVarName ',100,colorMatrix,''^'',''filled'');']);
      set(curScatter,'XDataSource',curColumnVarName,'YDataSource',...
        curLineVarName,'CDataSource','clusterColors')
    end
    % Small subplots data:
    curAxes = handles(curLine,curColumn);
    % Draw data:
    curScatter = scatter(curAxes,data(curColumn,:),...
      data(curLine,:),16,'filled');
    set(curScatter,'CDataSource','scatterColors');
    % Draw cluster centers 
    curColumnVarName = genvarname(sprintf('dimData_%d',curColumn));
    curLineVarName = genvarname(sprintf('dimData_%d',curLine));
    eval(['curScatter=scatter(curAxes,' curColumnVarName ',' ... 
      curLineVarName ',100,colorMatrix,''^'',''filled'');']);
    set(curScatter,'XDataSource',curColumnVarName,'YDataSource',...
      curLineVarName,'CDataSource','clusterColors');
    if curLine==nDim
      xlabel(curAxes,dimLabels{curColumn});
      set(curAxes,'XTick',xlim(curAxes));
    end
    if curColumn==1
      ylabel(curAxes,dimLabels{curLine});
      set(curAxes,'YTick',ylim(curAxes));
    end 
  end
end

refreshdata(figH,'caller');

% Preallocate gif frame. From Amro's tutorial here:
% https://stackoverflow.com/a/11054155/1162884
f = getframe(figH);
[f,map] = rgb2ind(f.cdata, 256, 'nodither');
mov = repmat(f, [1 1 1 nIters+4]);

% Add one frame at start conditions:
curFrame = 1;
f = getframe(figH);
mov(:,:,1,curFrame) = rgb2ind(f.cdata, map, 'nodither');

for curIter = 1:nIters
  curFrame = curFrame+1;
  % Change label text
  set(annH,'String',sprintf('Iteration %d',curIter));
  % Update cluster point colors
  scatterColors = colorMatrix(label{curIter+1},:);
  % Update cluster centers:
  for curDim=1:nDim
    curDimVarName = genvarname(sprintf('dimData_%d',curDim));
    eval([curDimVarName,'= m{curIter+1}(curDim,:);']);
  end
  % Update cluster colors:
  nClusterIter = size(m{curIter+1},2);
  clusterColors = colorMatrix(1:nClusterIter,:);
  % Update graphics:
  refreshdata(figH,'caller');
  % Update cluster colors:
  if nClusterIter~=size(m{curIter},2) % If number of cluster
    % of current iteration differs from previous iteration (or start
    % conditions in case we are at first iteration) we redraw colors: 
    pcolor([1:nClusterIter+1;1:nClusterIter+1])
    % image(reshape(colorMatrix,[1 size(colorMatrix)])); % Used image to
    % check if the upcoming buggy behaviour would be fixed. I was not
    % lucky, though...
    colormap(pColorAxes,clusterColors);
    % Put the XTick positions at the center of each color cell:
    set(pColorAxes,'XTick',1.5:1:nClusterIter+.5);
    set(pColorAxes,'YTick',[]);
    % Label each cell with its cluster number:
    set(pColorAxes,'XTickLabel',1:nClusterIter);
    xlabel(pColorAxes,'Clusters Colors','FontSize',plotOpts.FontSize);
  end
  f = getframe(figH);
  mov(:,:,1,curFrame) = rgb2ind(f.cdata, map, 'nodither');
end

set(annH,'String','Convergence Conditions');

for curFrame = nIters+2:nIters+4
  % Add three frames without movement at convergence conditions
  f = getframe(figH);
  mov(:,:,1,curFrame) = rgb2ind(f.cdata, map, 'nodither');
end

imwrite(mov, map, gifName, 'DelayTime',.5, 'LoopCount',inf)

varargout = cell(1,nargout);

if nargout > 0
  varargout{1} = label;
  if nargout > 1
    varargout{2} = m;
    if nargout > 2
      varargout{3} = figH;
      if nargout > 3
        varargout{4} = handles;
      end
    end
  end
end

end


function [leftPos,botPos,subplotWidth,subplotHeight] = ...
  setCustomPlotArea(handles,plotOpts,horSpace,vertSpace)
%
% -> handles: axes handles
%
% -> plotOpts: struct holding the following fields:
%
%   o leftBase: the percentage distance from the left
%
%   o rightBase: the percentage distance from the right
%
%   o bottomBase: the percentage distance from the bottom
%
%   o topBase: the percentage distance from the top
%
%   o widthUsableArea: Total width occupied by axes
%
%   o heigthUsableArea: Total height occupied by axes
%
% -> horSpace: the axes units size (integers only) that current axes
% should occupy in the horizontal (considering that other occupied
% axes handles are empty)
%
% -> vertSpace: the axes units size (integers only) that current axes
% should occupy in the vertical (considering that other occupied
% axes handles are empty)
%

nHorSubPlot =  size(handles,1);
nVertSubPlot = size(handles,2);

if nargin < 4
  vertSpace(nHorSubPlot,nVertSubPlot) = 0;
  vertSpace = vertSpace+1;
  if nargin < 3
    horSpace(nHorSubPlot,nVertSubPlot) = 0;
    horSpace = horSpace+1;
  end
end

subplotWidth = plotOpts.widthUsableArea/nHorSubPlot;
subplotHeight = plotOpts.heigthUsableArea/nVertSubPlot;

totalWidth = (1-plotOpts.rightBase) - plotOpts.leftBase;
totalHeight = (1-plotOpts.topBase) - plotOpts.bottomBase;

gapHeigthSpace = (totalHeight - ...
  plotOpts.heigthUsableArea)/(nVertSubPlot);
gapWidthSpace = (totalWidth - ...
  plotOpts.widthUsableArea)/(nHorSubPlot);

botPos(nVertSubPlot) = plotOpts.bottomBase + gapHeigthSpace/2;
leftPos(1) = plotOpts.leftBase + gapWidthSpace/2;

botPos(nVertSubPlot-1:-1:1) = botPos(nVertSubPlot) + (subplotHeight +...
  gapHeigthSpace)*(1:nVertSubPlot-1);
leftPos(2:nHorSubPlot) = leftPos(1) + (subplotWidth +...
  gapWidthSpace)*(1:nHorSubPlot-1);

for curLine=1:nHorSubPlot
  for curColumn=1:nVertSubPlot
    if handles(curLine,curColumn)
      set(handles(curLine,curColumn),'Position',[leftPos(curColumn)...
        botPos(curLine) horSpace(curLine,curColumn)*subplotWidth ...             
        vertSpace(curLine,curColumn)*subplotHeight]);                     
    end
  end                                                         
end                                                           

end


function [handles,horSpace,vertSpace] = ...
  createAxesGrid(nLines,nColumns,plotOpts,dimLabels)

handles = zeros(nLines,nColumns);

% Those hold the axes size units:
horSpace(nLines,nColumns) = 0;
vertSpace(nLines,nColumns) = 0;

for curColumn=1:nColumns
  for curLine=curColumn+1:nLines
    handles(curLine,curColumn) = subplot(nLines,...
      nColumns,curColumn+(curLine-1)*nColumns);
    horSpace(curLine,curColumn) = 1;
    vertSpace(curLine,curColumn) = 1;
    curAxes = handles(curLine,curColumn);
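    if feature('UseHG2')
      % NOTE: colorMatrix is not defined in this function's workspace, so
      % this call fails when the branch is taken unless colorMatrix is
      % also passed in as an argument.
      colormap(handle(curAxes),colorMatrix);
    end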
    set(curAxes,'NextPlot','add',...
      'FontSize',plotOpts.FontSize,'box','on'); 
    if curLine==nLines
      xlabel(curAxes,dimLabels{curColumn});
    else
      set(curAxes,'XTick',[]);
    end
    if curColumn==1
      ylabel(curAxes,dimLabels{curLine});
    else
      set(curAxes,'YTick',[]);
    end
  end
end
end

Example

Here is an example of how to use the code:

Example using 5 dimensions
center1 = [1; 0; 0; 0; 0];
center2 = [0; 1; 0; 0; 0];
center3 = [0; 0; 1; 0; 0];
center4 = [0; 0; 0; 1; 0];
center5 = [0; 0; 0; 0; 1];
center6 = [0; 0; 0; 0; 1.5];
center7 = [0; 0; 0; 1.5; 1];
data = [...
        bsxfun(@plus,center1,.5*rand(5,20)) ...
        bsxfun(@plus,center2,.5*rand(5,20)) ...
        bsxfun(@plus,center3,.5*rand(5,20)) ...
        bsxfun(@plus,center4,.5*rand(5,20)) ...
        bsxfun(@plus,center5,.5*rand(5,20)) ...
        bsxfun(@plus,center6,.2*rand(5,20)) ...
        bsxfun(@plus,center7,.2*rand(5,20)) ...
       ];
[label,m,figH,handles]=kmeans_test(data,20);
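
The same update-then-redraw pattern could be applied to the SOM from the question. The following is only a rough sketch under stated assumptions, not the asker's implementation: it assumes a 1-D chain of nodes, a Gaussian neighborhood, and linearly decaying learning rate and radius, all chosen for illustration. Note that in a SOM it is the node weights that move during training, while the data points stay fixed.

```matlab
% Minimal SOM sketch with a live plot (assumed: 1-D chain of nodes,
% Gaussian neighborhood, linearly decaying learning rate and radius).
rng(0);                                   % reproducible random numbers
data = [0.2*randn(2,50), bsxfun(@plus, [2;2], 0.2*randn(2,50))];  % 2 x n
nNodes = 10;
w = rand(2, nNodes);                      % node weights live in data space
nEpochs = 20;

figure('Color', 'w');
scatter(data(1,:), data(2,:), 16, 'filled'); hold on;
nodeH = plot(w(1,:), w(2,:), 'r-o');      % nodes joined along the chain
titleH = title('Start conditions');

for epoch = 1:nEpochs
  lr = 0.5*(1 - (epoch-1)/nEpochs);                  % learning rate
  sigma = max(0.5, (nNodes/2)*(1 - epoch/nEpochs));  % neighborhood radius
  for i = randperm(size(data, 2))
    x = data(:, i);
    [~, bmu] = min(sum(bsxfun(@minus, w, x).^2, 1));      % best matching unit
    h = exp(-((1:nNodes) - bmu).^2 / (2*sigma^2));         % neighborhood weights
    w = w + lr * bsxfun(@times, h, bsxfun(@minus, x, w));  % pull nodes toward x
  end
  set(nodeH, 'XData', w(1,:), 'YData', w(2,:));      % move the plotted nodes
  set(titleH, 'String', sprintf('Epoch %d', epoch));
  drawnow;                                           % or pause(0.2) to slow down
end
```

Capturing a frame with getframe after each drawnow, as in kmeans_test above, would turn this loop into a gif as well.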
