opencv 2.3.1使用bow,svm训练图像和加载/保存功能

时间:2015-04-20 06:51:20

标签: c++ xml opencv machine-learning svm

我是opencv中的新手,并尝试使用opencv2.3.1对两个图像类别进行分类。 这是我的代码。

void trainSVM(map<string,Mat>& classes_training_data, string& file_postfix, int response_cols, int response_type) {

//train 1-vs-all SVMs
vector<string> classes_names;
for (map<string,Mat>::iterator it = classes_training_data.begin(); it != classes_training_data.end(); ++it) {
    classes_names.push_back((*it).first);
}

string use_postfix = file_postfix;
for (int i=0;i<classes_names.size();i++) {
    string class_ = classes_names[i];

    Mat samples(0,response_cols,response_type);
    Mat labels(0,1,CV_32FC1);

    //copy class samples and label
    cout << "adding " << classes_training_data[class_].rows << " positive" << endl;
    samples.push_back(classes_training_data[class_]);
    Mat class_label = Mat::ones(classes_training_data[class_].rows, 1, CV_32FC1);
    labels.push_back(class_label);

    //copy rest samples and label
    for (map<string,Mat>::iterator it1 = classes_training_data.begin(); it1 != classes_training_data.end(); ++it1) {
        string not_class_ = (*it1).first;
        if(not_class_.compare(class_)==0) continue;
        samples.push_back(classes_training_data[not_class_]);
        class_label = Mat::zeros(classes_training_data[not_class_].rows, 1, CV_32FC1);
        labels.push_back(class_label);
    }

    cout << "Train.." << endl;
    Mat samples_32f; samples.convertTo(samples_32f, CV_32F);
    if(samples.rows == 0) continue; //phantom class?!
    CvSVM classifier; 
    classifier.train(samples_32f,labels);

    {
        stringstream ss; 
        ss << "SVM_classifier_"; 
        if(file_postfix.size() > 0) ss << file_postfix << "_";
        ss << class_ << ".yml";
        cout << "Save.." << endl;
        classifier.save(ss.str().c_str());
    }
}
}

我已成功保存了火车文件。当我尝试使用以下代码片段加载训练文件时:

classes_classifiers[catefory[i]].load(fclass.c_str());

它正常运行。 svm.get_support_vector_count()

它也是如此。 但是当添加时 svm.predict(descriptors,false);

它会报告错误

"OpenCV Error: Bad argument (The sample is not a valid vector) in cvPreparePredictData, file ~/OpenCV-2.3.1/modules/ml/src/inner_functions.cpp, line 1099
terminate called after throwing an instance of 'cv::Exception'
  what():  ~/modules/ml/src/inner_functions.cpp:1099: error: (-5) The sample is not a valid vector in function cvPreparePredictData"

有没有人可以帮我解决这个问题?

问候。

2 个答案:

答案 0 :(得分:0)

您好预测变量descriptors可能不是矢量,通常descriptors与您的一列火车样本数据具有相同的大小。 这是我的opencv SVM的示例代码:

int sample_num = 100;
int feature_size = 256;
Mat_<float> train_data(sample_num,feature_size);
//put your train sample datas
...
Mat_<int> label_data(sample_num,1);
//put your train sample label
TermCriteria criteria( CV_TERMCRIT_EPS, 1000, FLT_EPSILON );
SVMParams param( SVM::C_SVC, SVM::RBF, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria );
SVM svm;
//svm training
svm.train(train_data, label_data, Mat(), Mat(), param);
svm.save("svmtest.xml");
SVM _svm;
_svm.load("svmtest.xml");
Mat_<float> test_data(1,feature_size);
//put your test data
...
int predict_label = _svm.predict(test_data);

答案 1 :(得分:0)

我认为你缺少词汇量。在加载保存在.yml文件中的训练数据之前,你必须保存你的词汇:当你使用BOW描述符(BOWImgDescriptorExtractor&amp; BOWKMeansTrainer)时,你可以加载它在另一个项目中,用它计算测试图像的新描述符,并使用保存的yml文件进行训练。希望这可以帮到你

相关问题