Matlab调试 - 初学者级别

时间:2011-09-17 05:15:00

标签: matlab machine-learning

我是Matlab的初学者,并试图在Matlab中编写一些机器学习算法。如果有人可以帮我调试这段代码,我真的很感激。

function y = KNNpredict(trX,trY,K,X)
   % trX is NxD, trY is Nx1, K is 1x1 and X is 1xD
   % we return a single value 'y' which is the predicted class

% TODO: write this function
% int[] distance = new int[N];
distances = zeroes(N, 1);
examples = zeroes(K, D+2);
i = 0;
% for(every row in trX) { // taking ONE example
for row=1:N, 
 examples(row,:) = trX(row,:);
 %sum = 0.0;
 %for(every col in this example) { // taking every feature of this example
 for col=1:D, 
    % diff = compute squared difference between these points - (trX[row][col]-X[col])^2
    diff =(trX(row,col)-X(col))^2;
    sum += diff;
 end % for
 distances(row) = sqrt(sum);
 examples(i:D+1) = distances(row);
 examples(i:D+2) = trY(row:1);
end % for

% sort the examples based on their distances thus calculated
sortrows(examples, D+1);
% for(int i = 0; i < K; K++) {
% These are the nearest neighbors
pos = 0;
neg = 0;
res = 0;
for row=1:K,
    if(examples(row,D+2 == -1))
        neg = neg + 1;
    else
        pos = pos + 1;
    %disp(distances(row));
    end
end % for

if(pos > neg)
    y = 1;
    return;
else
    y = -1;
    return;
end
end
end

非常感谢

1 个答案:

答案 0 :(得分:2)

在MATLAB中使用矩阵时,通常最好避免过多的循环,而是尽可能使用矢量化操作。这通常会产生更快更短的代码。

在你的情况下,k-最近邻算法很简单,可以很好地矢量化。请考虑以下实现:

function y = KNNpredict(trX, trY, K, x)
    %# euclidean distance between instance x and every training instance
    dist = sqrt( sum( bsxfun(@minus, trX, x).^2 , 2) );

    %# sorting indices from smaller to larger distances
    [~,ord] = sort(dist, 'ascend');

    %# get the labels of the K nearest neighbors
    kTrY = trY( ord(1:min(K,end)) );

    %# majority class vote
    y = mode(kTrY);
end

以下是使用Fisher-Iris数据集测试它的示例:

%# load dataset (data + labels)
load fisheriris
X = meas;
Y = grp2idx(species);

%# partition the data into training/testing
c = cvpartition(Y, 'holdout',1/3);
trX = X(c.training,:);
trY = Y(c.training);
tsX = X(c.test,:);
tsY = Y(c.test);

%# prediction
K = 10;
pred = zeros(c.TestSize,1);
for i=1:c.TestSize
    pred(i) = KNNpredict(trX, trY, K, tsX(i,:));
end

%# validation
C = confusionmat(tsY, pred)

kNN预测的混淆矩阵,K = 10:

C =
    17     0     0
     0    16     0
     0     1    16