我使用SVM进行分类,我在一个项目中进行培训,在另一个项目中进行测试,以便只训练一次。
培训部分如下:
classifier->trainAuto(trainData);
string svmDir = "/File/Dir/";
string svmFile = "svmClassifier.xml";
classifier->save(svmDir+svmFile);
TESTing部分是:
string svmDir = "/File/Dir/";
string svmFile = "svmClassifier.xml";
Ptr<ml::SVM> classifier = ml::SVM::load<ml::SVM>(svmDir+svmFile);
...
float response = classifier->predict(tDescriptor);
预测给出全0(全部为负)。但是当我在训练项目中进行SVM训练后立即进行预测时,预测是正确的(我在#34之前使用断点;预测&#34;,传递给预测的tDescriptor在两个项目中是相同的。) 所以我认为保存和加载过程可能有问题。
是否可以保存并加载自动训练的SVM?或者它必须在statModel?
感谢您的帮助!
答案 0 :(得分:0)
流动的代码是从Introduction to Support Vector Machines for open CV获取和修改的。我从对象svmOld“trainedSVM.xml”中保存了SVM参数。然后加载XML文件并使用它们来创建对象svmNew。
#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
#include "opencv2/imgcodecs.hpp"
#include <opencv2/highgui.hpp>
#include <opencv2/ml.hpp>
using namespace cv;
using namespace cv::ml;
using namespace std;
int main()
{
// Data for visual representation
int width = 512, height = 512;
Mat image = Mat::zeros(height, width, CV_8UC3);
// Set up training data
int Lable[] = { 1, -1, -1, -1 };
Mat labelsMat(4, 1, CV_32S, Lable);
float trainingData[4][2] = { { 501, 10 },{ 255, 10 },{ 501, 255 },{ 10, 501 } };
Mat trainingDataMat(4, 2, CV_32FC1, trainingData);
// Set up SVM's parameters
Ptr<SVM> svmOld = SVM::create();
svmOld->setType(SVM::C_SVC);
svmOld->setKernel(SVM::LINEAR);
svmOld->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 1e-6));
// Train the SVM with given parameters
Ptr<TrainData> td = TrainData::create(trainingDataMat, ROW_SAMPLE, labelsMat);
svmOld->train(td);
//same svm
svmOld->save("trainedSVM.xml");
//Initialize SVM object
Ptr<SVM> svmNew = SVM::create();
//Load Previously saved SVM from XML
svmNew = SVM::load<SVM>("trainedSVM.xml");
Vec3b green(0, 255, 0), blue(255, 0, 0);
// Show the decision regions given by the SVM
for (int i = 0; i < image.rows; ++i)
for (int j = 0; j < image.cols; ++j)
{
Mat sampleMat = (Mat_<float>(1, 2) << j, i);
float response = svmNew->predict(sampleMat);
if (response == 1)
image.at<Vec3b>(i, j) = green;
else if (response == -1)
image.at<Vec3b>(i, j) = blue;
}
// Show the training data
int thickness = -1;
int lineType = 8;
circle(image, Point(501, 10), 5, Scalar(0, 0, 0), thickness, lineType);
circle(image, Point(255, 10), 5, Scalar(255, 255, 255), thickness, lineType);
circle(image, Point(501, 255), 5, Scalar(255, 255, 255), thickness, lineType);
circle(image, Point(10, 501), 5, Scalar(255, 255, 255), thickness, lineType);
// Show support vectors
thickness = 2;
lineType = 8;
Mat sv = svmNew->getSupportVectors();
for (int i = 0; i < sv.rows; ++i)
{
const float* v = sv.ptr<float>(i);
circle(image, Point((int)v[0], (int)v[1]), 6, Scalar(128, 128, 128), thickness, lineType);
}
imwrite("result.png", image); // save the image
imshow("SVM Simple Example", image); // show it to the user
waitKey(0);
return(0);
}