Matrix arguments to a MATLAB function

Date: 2011-01-30 03:17:18

Tags: matlab

I came across a blog post on computing the K nearest neighbors, shown below:

function test_targets = knn(train_patterns, train_targets, test_patterns, K)
    % Contact Budi Santosa at budi_s@ie.its.ac.id
    % to report bugs.
    % Implementation of the nearest neighbor algorithm
    % Inputs:
    %  train_patterns  - Train patterns (dim x obs) D x N
    %  train_targets   - Train targets              1 x N (classes)
    %  test_patterns   - Test  patterns             D x M (M testing)
    %  K               - number of nearest neighbors
    %
    % Outputs
    %  test_targets    - Predicted targets          1 x M

    L  = length(train_targets);   % number of training samples N
    Uc = unique(train_targets);   % distinct class labels, sorted

    if (L < K)
       error('More neighbors requested than there are training points')
    end

    M            = size(test_patterns, 2);   % test samples are the columns
    test_targets = zeros(1, M);
    for i = 1:M
        jar  = (train_patterns - repmat(test_patterns(:,i), 1, L)).^2;
        dist = sum(jar, 1);                  % squared distance from test point i to each training point
        [~, indices] = sort(dist);           % sort distances, smallest first
        yt   = train_targets(indices(1:K));  % labels of the K nearest training points
        n    = hist(yt, Uc);                 % count those labels per class in Uc

        [~, best] = max(n);                  % most frequent class among the K nearest neighbors

        test_targets(i) = Uc(best);
    end

My problem is that I keep getting the following MATLAB message:

??? Error using ==> minus
Matrix dimensions must agree.

I have two matrices:

A is NxD A = 
670.00 1630.00 2380.00 1
721.00 1680.00 2400.00 1
750.00 1710.00 2440.00 1
660.00 1800.00 2150.00 1
660.00 1800.00 2150.00 1
680.00 1958.00 2542.00 1
440.00 1120.00 2210.00 2
400.00 1070.00 2280.00 2

B is MxD B =
750.00 1710.00 2440.00 1
680.00 1910.00 2440.00 1
500.00 1000.00 2325.00 2
500.00 1000.00 2325.00 2

As you can see, the 4th column holds the class of each example. This is how I set up the arguments:

train_patterns  = A(:,:)     %HOW TO PASS A??, A(:,1:3)? A(1:size(B,1),:) ??  which????   
train_targets   = A(:,4)     %pass the column 4 as vector of classes 
test_patterns   = B(:,1:3)   %pass only the 3 columns
Knn             = 3

So the output should be a 1 x M vector with a prediction for every example in B. How can I do this?

1 answer:

Answer 0 (score: 1)

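You need to transpose A and B from N x D to D x N (using the ' operator), and keep the class-label column out of the pattern matrices.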

So:

train_patterns = A(:,1:3)'; % 3-by-N
train_targets = A(:,4)';    % 1-by-N
test_patterns = B(:,1:3)';  % 3-by-M (keep B(:,4) to check the predictions)
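A minimal end-to-end sketch with the data from the question, assuming the D x N convention from the function's header comments. So that it runs without knn.m on the path, the voting step is written out inline; it uses bsxfun in place of repmat for the subtraction and mode in place of the hist/max pair (both break ties toward the smallest class label). The names `predicted` and `accuracy` are illustrative only.

```matlab
% Data from the question: columns 1-3 are features, column 4 is the class label
A = [670 1630 2380 1
     721 1680 2400 1
     750 1710 2440 1
     660 1800 2150 1
     660 1800 2150 1
     680 1958 2542 1
     440 1120 2210 2
     400 1070 2280 2];
B = [750 1710 2440 1
     680 1910 2440 1
     500 1000 2325 2
     500 1000 2325 2];

train_patterns = A(:,1:3)';   % 3-by-8 (D x N)
train_targets  = A(:,4)';     % 1-by-8
test_patterns  = B(:,1:3)';   % 3-by-4 (D x M)
K = 3;

% Check the shapes the D x N interface expects before doing any arithmetic
assert(size(train_patterns, 1) == size(test_patterns, 1));  % same D
assert(size(train_patterns, 2) == numel(train_targets));    % same N

M = size(test_patterns, 2);
predicted = zeros(1, M);
for i = 1:M
    % squared Euclidean distance from test column i to every training column
    d = sum(bsxfun(@minus, train_patterns, test_patterns(:, i)).^2, 1);
    [~, idx] = sort(d);
    % majority label among the K nearest (mode breaks ties toward the smaller label)
    predicted(i) = mode(train_targets(idx(1:K)));
end

disp(predicted)                          % expected: 1 1 2 2
accuracy = mean(predicted == B(:,4)')    % fraction of correct predictions
```

The two assert lines catch the "Matrix dimensions must agree" situation early: passing A(:,:) with its label column, or forgetting a transpose, makes the row counts or the N counts disagree.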