How to implement k-fold cross-validation with a multi-class SVM

Date: 2015-10-18 05:05:15

Tags: matlab svm cross-validation

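I'm working on age prediction. I implemented a multi-class SVM with 11 classes by training each class as positive versus all the rest, as shown here and here. The problem is the for loop below, because training takes 11 iterations: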

for k = 1:numClasses
    %Vectorized statement that binarizes Group
    %where 1 is the current class and 0 is all other classes
    G_x_All = (train_label == u(k));
    G_x_All = double (G_x_All);
    SVMStruct{k} = svmtrain(Data_Set, G_x_All);
end

Classifying each image then takes another 11 iterations:

for j = 1:total_images
  for k = 1:numClasses
      if(svmclassify(SVMStruct{k}, Test_Img(j,:)));
          break;
      end
  end
  Age (j) = u(k); % Put the number of correct class in Age vector
end

My beginner question: how do I perform k-fold cross-validation after all these loops?

EDIT:

Following zelanix's suggestion, this is the latest version of my code, but the results are poor. Can you help me improve its performance?

u = unique(train_label);
numClasses = length (u);
N = size (Data_Set,1)
A = 10;
indices = crossvalind('Kfold', N, A);
cp = classperf (train_label);

for i = 1:A
    Test = (indices == i); 
    Train = ~Test;         
    SVMStruct = cell(numClasses, 1); % Clear data structure.

    % Build models
  for k = 1:numClasses
    %Vectorized statement that binarizes Group
    %where 1 is the current class and 0 is all other classes
    G_x_All = (train_label == u(k));
    G_x_All = double (G_x_All);
    SVMStruct{k} = svmtrain(Data_Set (Train,:), G_x_All(Train,:));
  end

  Age = NaN(size(Data_Set, 1), 1);

  % Classify test cases
  for k = 1:numClasses
      if(svmclassify(SVMStruct{k}, Data_Set(Test,:)));
          break;
      end
  end
   Age = u(k);
   if Age == 1
       disp ('Under 10 years old');
   elseif Age == 10
       disp ('Between 10 and 20 years old');
   elseif Age == 20
       disp ('Between 20 and 30 years old');
   elseif Age == 30
       disp ('Between 30 and 40 years old');
   elseif Age == 40
       disp ('Between 40 and 50 years old');
   elseif Age == 50
       disp ('Between 50 and 60 years old');
   elseif Age == 60
       disp ('Upper 60 years old');
   else
       disp ('Unknown');
   end

classperf(cp, Age, Test);
    disp (i)
end
cp.CorrectRate

Note that I reduced the number of labels from 11 to 7.

1 answer:

Answer (score: 1)

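The general structure you need is as follows. It assumes your data is in a variable your_data of size N x M, where N is the number of samples and M is the number of features, and that your class labels are in a variable your_classes of size N x 1: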

K = 10; % The number of folds
N = size(your_data, 1); % The number of data samples to train / test
idx = crossvalind('Kfold', N, K); % Fold index (1..K) for every sample

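% your_classes should contain the class labels (e.g. 1 to numClasses).
u = unique(your_classes);   % The distinct class labels
numClasses = numel(u);      % The number of classes
cp = classperf(your_classes);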

for i = 1:K
    Data_Set = your_data(idx ~= i, :); % The data to train on, 90% of the total.
    train_label = your_classes(idx ~= i, :); % The class labels of your training data.
    Test_Img = your_data(idx == i, :); % The data to test on, 10% of the total.
    test_label = your_classes(idx == i, :); % The class labels of your test data.

    SVMStruct = cell(numClasses, 1); % Clear data structure.

    % Your training routine, copied verbatim
    for k = 1:numClasses
        %Vectorized statement that binarizes Group
        %where 1 is the current class and 0 is all other classes
        G_x_All = (train_label == u(k));
        G_x_All = double (G_x_All);
        SVMStruct{k} = svmtrain(Data_Set, G_x_All);
    end

    Age = NaN(size(Test_Img, 1), 1);

    % Your test routine, copied (almost) verbatim
    for j = 1:size(Test_Img, 1)
      for k = 1:numClasses
          if(svmclassify(SVMStruct{k}, Test_Img(j,:)));
              break;
          end
      end
      Age(j) = u(k); % Put the number of correct class in Age vector
    end

    cp = classperf(cp, Age, idx == i);
end

cp.CorrectRate
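crossvalind ships with the Bioinformatics Toolbox. If you only have the Statistics and Machine Learning Toolbox, cvpartition can produce equivalent train/test masks. The following is a minimal sketch with made-up labels (your_classes here is a placeholder, not your real data). It only builds the folds and checks that every sample lands in exactly one test fold.

```matlab
% Hypothetical labels standing in for your_classes (N x 1).
rng(1);                                   % make the partition reproducible
your_classes = [ones(30,1); 2*ones(30,1); 3*ones(30,1)];
N = numel(your_classes);
K = 10;                                   % number of folds

% Passing the labels (not just N) gives stratified folds:
% each fold keeps roughly the same class proportions.
c = cvpartition(your_classes, 'KFold', K);

timesTested = zeros(N, 1);                % how often each sample is in a test fold
for i = 1:c.NumTestSets
    trainMask = training(c, i);           % logical N x 1 mask, about 90% of samples
    testMask  = test(c, i);               % logical N x 1 mask, about 10% of samples
    % ... train on your_data(trainMask,:) and test on your_data(testMask,:) here ...
    timesTested = timesTested + testMask;
end
```

Stratification matters when some age groups are rare. With plain random folds, a fold can end up with no examples of a class.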

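This is untested, and I'm not sure how your classification actually works. You seem to break on the first classifier that matches, which may not be the correct classification, or even the most likely one. You also need some way to record the result and compare it with the true class labels in test_label. I suggest looking at the classperf function, but that is a separate question.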
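To record and compare the predictions of one fold against test_label without classperf, you can compute the accuracy and a confusion matrix directly. The label values below are invented purely for illustration.

```matlab
% Hypothetical true and predicted labels for one test fold.
test_label = [10; 10; 20; 20; 30; 30];
Age        = [10; 20; 20; 20; 30; 10];   % predictions from your classifier

foldAccuracy = mean(Age == test_label);  % fraction classified correctly (4/6 here)

% Rows are true classes and columns are predicted classes,
% both in sorted label order (10, 20, 30).
C = confusionmat(test_label, Age);
```

Summing C over all folds gives the confusion matrix for the whole cross-validation. That matrix shows which age groups get confused with each other, not just the overall rate.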

Also note that MATLAB has built-in multi-class SVM classification in the fitcecoc function, which may suit your needs better.
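The sketch below shows fitcecoc with a one-vs-all coding and its built-in k-fold support. It assumes the Statistics and Machine Learning Toolbox, and uses synthetic 2-D data with made-up age-group labels in place of your features.

```matlab
% Synthetic stand-in for Data_Set / train_label: three well-separated
% 2-D clusters labelled with hypothetical age groups 10, 20 and 30.
rng(0);
n = 50;                                      % samples per class
X = [randn(n,2); ...
     randn(n,2) + repmat([4 4], n, 1); ...
     randn(n,2) + repmat([8 0], n, 1)];
Y = [10*ones(n,1); 20*ones(n,1); 30*ones(n,1)];

% One binary linear SVM per class (one-vs-all), combined by fitcecoc.
t = templateSVM('Standardize', true);
Mdl = fitcecoc(X, Y, 'Learners', t, 'Coding', 'onevsall');

% 10-fold cross-validation: trains on 9/10 of the data and tests on the
% remaining 1/10, once per fold.
CVMdl = crossval(Mdl, 'KFold', 10);
cvError = kfoldLoss(CVMdl);                  % mean misclassification rate
cvAccuracy = 1 - cvError;
predAll = kfoldPredict(CVMdl);               % out-of-fold prediction per sample
```

fitcecoc defaults to one-vs-one coding. 'onevsall' is used here only to mirror the scheme in the question.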

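EDIT: As mentioned above, the problem with your updated code is your classification method. You loop over the classes, test whether the sample belongs to each one, and break on the first class that matches. That is unlikely to be the most probable classification, so I'm not surprised your results are poor.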

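Your sample may match the first model by only a small margin and never reach the model it matches best. It is unlikely ever to get past the first few classifiers. And what do you do if it reaches the end and none of the classifiers matched?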

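Multi-class classification with SVMs is usually done by choosing the class the sample is most likely to belong to. Among all the classes that test positive, you pick the one whose decision boundary is farthest from the sample, and fitcecoc does this internally. You can't do this manually with svmclassify, because it doesn't give you access to those details.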
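fitcsvm is the replacement for svmtrain, and its predict method returns a score for each class. That lets you replace the break-on-first-match loop with "pick the class whose one-vs-all SVM gives the largest positive score". The following is a minimal sketch of that idea inside a 10-fold loop. The data and the age labels are synthetic. Raw SVM scores are not calibrated probabilities, so comparing them across classifiers is a heuristic, although it is the usual one-vs-all rule.

```matlab
% Synthetic data: three 2-D clusters with hypothetical age-group labels.
rng(0);
n = 50;
Data_Set = [randn(n,2); ...
            randn(n,2) + repmat([4 4], n, 1); ...
            randn(n,2) + repmat([8 0], n, 1)];
train_label = [10*ones(n,1); 20*ones(n,1); 30*ones(n,1)];

u = unique(train_label);
numClasses = numel(u);
N = size(Data_Set, 1);
c = cvpartition(train_label, 'KFold', 10);   % stratified 10-fold split
Age = NaN(N, 1);                             % out-of-fold predictions

for i = 1:c.NumTestSets
    Train = training(c, i);
    Test  = test(c, i);
    XTest = Data_Set(Test, :);
    scores = zeros(size(XTest, 1), numClasses);

    for k = 1:numClasses
        % Binary problem: class u(k) (label 1) versus all other classes (label 0).
        G_x_All = double(train_label(Train) == u(k));
        Mdl = fitcsvm(Data_Set(Train, :), G_x_All, 'Standardize', true);
        [~, s] = predict(Mdl, XTest);
        % Score columns follow Mdl.ClassNames ([0; 1]); keep the positive class.
        scores(:, k) = s(:, Mdl.ClassNames == 1);
    end

    % Choose the most confident classifier instead of the first "yes".
    % This also handles samples that no classifier accepts.
    [~, best] = max(scores, [], 2);
    Age(Test) = u(best);
end

accuracy = mean(Age == train_label);         % overall cross-validated accuracy
```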

fitcecoc still doesn't give you access to every value you might need. If you really want to do this manually, I suggest looking at libsvm. Otherwise, use fitcecoc.

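As mentioned in the comments, svmtrain and svmclassify are now deprecated. libsvm also offers much more scope for tuning and performance than the built-in MATLAB implementation.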

Incidentally, multi-class logistic regression is much simpler to understand and can also give good results.
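For comparison, here is a minimal sketch of multinomial logistic regression using mnrfit and mnrval from the Statistics and Machine Learning Toolbox. The data are synthetic, partly overlapping clusters. mnrfit expects the response as integer class indices 1..k, so real age labels would first have to be mapped to indices (for example, positions in unique(train_label)). Only training accuracy is computed here. In practice you would wrap this in the same k-fold loop as the SVM.

```matlab
% Synthetic, partly overlapping 2-D clusters.
rng(0);
n = 50;
X = [randn(n,2); ...
     randn(n,2) + repmat([2 2], n, 1); ...
     randn(n,2) + repmat([4 0], n, 1)];
Y = [ones(n,1); 2*ones(n,1); 3*ones(n,1)];   % class indices 1..3

B = mnrfit(X, Y);              % nominal multinomial logistic regression
P = mnrval(B, X);              % N x 3 matrix of class probabilities
[~, pred] = max(P, [], 2);     % most probable class for each sample
trainAccuracy = mean(pred == Y);
```

Unlike raw SVM scores, the columns of P are probabilities that sum to 1 for each sample, so "pick the largest" is directly meaningful.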