K-fold cross-validation in MATLAB

Time: 2017-01-24 23:03:09

Tags: matlab

I have data for 36 users.

As the (MATLAB) code below shows, the data is loaded into a cell array named Feat_Vec1:
clear;
for nc = 1:36   % nc: user number
    % Load this user's data into MATLAB
    data{nc} = load(sprintf('U%02d_Acc_TimeD_FreqD_FDay.mat', nc));

    % Store the user's feature matrix in a cell
    Feat_Vec1{nc} = data{nc}.Acc_TD_Feat_Vec(:,:);
end

For each user I have 36 rows and 143 columns: Feat_Vec1 contains 36 cells (one per user), and each cell holds a 36 × 143 matrix.

I want to use 9-fold cross-validation to split my dataset into training and test sets.

I have read the documentation in MATLAB help but don't understand it! Could someone help me write 9-fold cross-validation code for each user?

3 answers:

Answer 0 (score: 1)

Continuing from your code:

% Combine data into a single matrix:
data = []; % sorry for the dynamic allocation
for i = 1:36
    data = [data; Feat_Vec1{i}];   % index with the loop variable i, not nc
end

N = size(data, 1); % should be 36^2 = 1296
K = 9;
% create a column vector that has K (=9) blocks of length N/K, such as [1 1 ... 1 2 2 ... 2 ... 9 9 ... 9]'
folds = []; % again, sorry for the dynamic allocation
for i=1:K
    folds = [folds; repmat(i, N/K, 1)];
end
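Both loops above grow an array on every iteration (hence the apologies), and once the rows are stacked user by user, the contiguous folds put 4 whole users (144 rows) into each test fold. A minimal sketch of an allocation-free alternative uses `vertcat` to stack the cells and `repelem` to build the same fold vector. Random data is a hypothetical stand-in for Feat_Vec1, and N is assumed divisible by K (1296/9 = 144). The optional `randperm` line shuffles the fold labels if you would rather have every fold mix rows from all users:

```matlab
% Hypothetical stand-in for the question's Feat_Vec1: 36 users, 36 x 143 each
Feat_Vec1 = cell(1, 36);
for nc = 1:36
    Feat_Vec1{nc} = randn(36, 143);
end

% Stack all users in order without growing an array inside a loop
data = vertcat(Feat_Vec1{:});      % 1296 x 143

N = size(data, 1);
K = 9;
% Same block vector as the loop above: N/K copies of each fold id 1..K
folds = repelem((1:K)', N/K);

% Optional: shuffle so each fold mixes rows from all users
% instead of holding 4 whole users
foldsShuffled = folds(randperm(N));
```

Keeping the contiguous blocks evaluates the model on users it has never seen; shuffling evaluates it on unseen rows from users that also appear in training. Which one fits depends on the question being asked.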

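Now you have the fold indices. Collecting the data into a single matrix isn't strictly necessary at this point, since we already know N. However, that variable can be useful when you train and test inside the for loop: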

accuracies = zeros(1, K);
for fold = 1:K
    testIds = find(folds == fold);
    trainIds = find(folds ~= fold);
    % train your algorithm (train/test are placeholders for your own
    % functions; label is your N x 1 vector of class labels)
    model = train(data(trainIds,:), label(trainIds,:));
    % evaluate on the testing fold
    accuracies(fold) = test(model, data(testIds,:), label(testIds,:));
end
mean(accuracies)
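In the loop above, `train`, `test` and `label` stand for whatever classifier and labels you actually use. To see the loop run end to end, here is a minimal sketch in which a simple nearest-centroid classifier (a substitution chosen for illustration, not something the answer prescribes) plays the role of train/test. The data and the two classes are hypothetical random values:

```matlab
% Hypothetical data: random features and two made-up, alternating classes
N = 1296; D = 143; K = 9;
label = mod((0:N-1)', 2) + 1;          % classes 1 and 2
data = randn(N, D) + 2 * label;        % class 2 rows are shifted, so classes are separable

% Shuffled fold assignment, N/K rows per fold
folds = repelem((1:K)', N/K);
folds = folds(randperm(N));

accuracies = zeros(1, K);
for fold = 1:K
    testIds  = find(folds == fold);
    trainIds = find(folds ~= fold);

    % "train": one mean feature vector (centroid) per class
    trainLabels = label(trainIds);
    classes = unique(trainLabels);
    centroids = zeros(numel(classes), D);
    for c = 1:numel(classes)
        centroids(c, :) = mean(data(trainIds(trainLabels == classes(c)), :), 1);
    end

    % "test": assign each test row to the class with the nearest centroid
    Xtest = data(testIds, :);
    dists = zeros(numel(testIds), numel(classes));
    for c = 1:numel(classes)
        dists(:, c) = sum((Xtest - centroids(c, :)).^2, 2);
    end
    [~, nearest] = min(dists, [], 2);
    predicted = classes(nearest);
    accuracies(fold) = mean(predicted == label(testIds));
end
mean(accuracies)
```

Any classifier can replace the two marked sections, as long as it is fitted only on `trainIds` and scored only on `testIds`.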

Hope this helps.

Answer 1 (score: 1)

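I'll modify your code to show how to do 9-fold cross-validation for each user independently. This means every user gets their own train/test folds. First, 9-fold cross-validation means that each user's data is split so that 8/9 of it is used for training and 1/9 for testing, and this is repeated nine times, once per fold.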
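For one user's 36 rows the split works out exactly, because 36 is divisible by 9; over the nine repetitions every row lands in the test set exactly once:

```latex
n_{\text{test}} = \frac{36}{9} = 4, \qquad n_{\text{train}} = 36 - 4 = 32 = \frac{8}{9} \cdot 36
```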

clear;
for nc = 1:36  % nc: user number
    % probably you don't need to save data in a cell - like data{nc}
    data = load(sprintf('U%02d_Acc_TimeD_FreqD_FDay.mat', nc));
    data = data.Acc_TD_Feat_Vec(:,:);
    % assign each of the 36 rows to one of 9 folds
    % (crossvalind requires the Bioinformatics Toolbox)
    ind = crossvalind('Kfold', 36, 9);
    for fold = 1:9
        Feat_Vec1_train{nc}{fold} = data(ind ~= fold, :);
        Feat_Vec1_test{nc}{fold}  = data(ind == fold, :);
    end
end

In the code above, each user has 9 train/test pairs. For example, the 8th train/test pair of the 3rd user can be accessed as:

Feat_Vec1_train{3}{8}
Feat_Vec1_test{3}{8}
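`crossvalind` ships with the Bioinformatics Toolbox. If that toolbox is not installed, a minimal sketch of the same per-user split can build the fold labels with `repelem` and `randperm` instead. Here `randn` data is a hypothetical stand-in for the .mat files, and the row count is assumed divisible by K (36/9 = 4):

```matlab
nUsers = 36;
K = 9;
Feat_Vec1_train = cell(1, nUsers);
Feat_Vec1_test  = cell(1, nUsers);
for nc = 1:nUsers
    data = randn(36, 143);             % hypothetical: replace with your load(...) call
    n = size(data, 1);
    % toolbox-free replacement for crossvalind('Kfold', n, K)
    ind = repelem((1:K)', n/K);        % n/K rows per fold
    ind = ind(randperm(n));            % random fold label for every row
    Feat_Vec1_train{nc} = cell(1, K);
    Feat_Vec1_test{nc}  = cell(1, K);
    for fold = 1:K
        Feat_Vec1_train{nc}{fold} = data(ind ~= fold, :);
        Feat_Vec1_test{nc}{fold}  = data(ind == fold, :);
    end
end

% 8th train/test pair of the 3rd user
size(Feat_Vec1_train{3}{8})   % 32 rows x 143 columns
size(Feat_Vec1_test{3}{8})    % 4 rows x 143 columns
```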

Answer 2 (score: 1)

A slight modification of the code above:

clear;
for nc = 1:36  % nc number of users
    % probably you don't need to save data in a cell - like data{nc}
    data = load( sprintf('U%02d_Acc_TimeD_FreqD_FDay.mat', nc));
    data = data.Acc_TD_Feat_Vec(:,:);
    ind = crossvalind('Kfold', 36, 9);
    for fold = 1:9
        Feat_Vec1_train{nc}{fold} = data(ind ~= fold, :);
        Feat_Vec1_test{nc}{fold}  = data(ind == fold, :);
    end
end