在MATLAB上训练CNN时如何设置验证集(trainNetwork())

时间:2017-07-23 17:05:14

标签: matlab machine-learning neural-network deep-learning virtual-machine

我正在尝试在MATLAB上训练CNN。 matlab文档说,加载数据,设置图层和选项。最后使用trainNetwork()进行训练。

    layers = [imageInputLayer([28 28 1])
          convolution2dLayer(5,10,...
                                'Stride',1,...
                                'Padding',[0,0])
          reluLayer
          maxPooling2dLayer(2,'Stride',2)
          fullyConnectedLayer(10)
          softmaxLayer
          classificationLayer];


 options = trainingOptions('sgdm',...Environment
                            'CheckpointPath','',...
                            'ExecutionEnvironment','gpu',...                'auto'  | 'cpu' | 'gpu' | 'multi-gpu' | 'parallel'
                            'InitialLearnRate',0.0001,...   Learning Rate
                            'LearnRateSchedule','none',...                  none    |piecewise
                            'LearnRateDropPeriod',10,...
                            'LearnRateDropFactor',0.1,...
                            'L2Regularization',0.0001,...   Regularization
                            'MaxEpochs',15,...              Epochs
                            'MiniBatchSize',128,...         Batch           128     |
                            'Momentum',0.9,...                              0.9     |
                            'Shuffle','once',...                            once    |never
                            'Verbose',1,...                                 1       | 0             — Indicator to display the information on the training progress
                            'VerboseFrequency',100,...                      50      | 0 
                            'OutputFcn',@plotTrainingAccuracy);

convnet = trainNetwork(trainDigitData,layers,options);

下面是我培训CNN的程序,但问题是我找不到设置验证集的选项。我设置'时代'的数字越大,它训练的时间就越长。它会在过度拟合之前停止吗?

不喜欢nnstart工具箱,当训练NN时,它会显示交叉熵和验证,训练错误率。

那么,在matlab上训练CNN时你通常会使用什么?使用像caffe这样的第三方lib界面?或自己编写程序?

1 个答案:

答案 0 :(得分:0)

您可以将数据拆分为训练和测试数据

idx = floor(0.8 * height(data));
trainingData = data(1:idx,:);
testData = data(idx:end,:);

然后在trainNetwork之后,可以运行测试部分

resultsStruct = struct([]);

for i = 1:height(testData)

    % Read the image.
    I = imread(testData.imageFilename{i});
    % Run the detector.
    [bboxes, scores, labels] = detect(detector, I);

    % Collect the results.
    resultsStruct(i).Boxes = bboxes;
    resultsStruct(i).Scores = scores;
    resultsStruct(i).Labels = labels;
end

% Convert the results into a table.
results = struct2table(resultsStruct);

如果您要检查实施以实现更快的R-CNN实现https://www.mathworks.com/help/vision/examples/object-detection-using-faster-r-cnn-deep-learning.html