我正在尝试使用OpenCV 3.1.0来训练MNIST数据集的NB分类器。我使用来自http://pjreddie.com/projects/mnist-in-csv/的准备好的CSV培训和数据文件来训练NB分类器。我使用剪切和粘贴稍微修改了这个CSV文件,以满足OpenCV的要求。训练分类器后,我尝试使用它来对训练数据集进行分类,但它将所有样本分类为0类。训练数据集有784个维度,10个类和60000个样本。我的培训代码如下:
#include <iostream>
#include <opencv2/ml.hpp>
using namespace cv;
using namespace cv::ml;
int main(int argc, char* argv[])
{
String trainingDataFile(argv[1]);
Ptr<TrainData> trainingData = TrainData::loadFromCSV(trainingDataFile,0);
Ptr<NormalBayesClassifier> nbClassifier = NormalBayesClassifier::create();
nbClassifier->train(trainingData);
nbClassifier->save(trainingDataFile+"_trainedNBParams.dat");
return 0;
}
测试代码只是从文件中重新加载NB分类器并对所有样本进行分类。我成功地将此代码用于另一个具有128维,10个类和10000个样本的较小数据集。我不确定我的代码,培训方法或OpenCV本身是否存在问题。请指教。
谢谢。
答案 0 :(得分:0)
如果没有其他信息,很难判断问题可能是什么。
我的猜测是在数据中,因为这是唯一已经改变并导致失败的变量。你能尝试迭代训练数据
吗?for(int i = 0; i < trainingData->getSamples().size(); i++) {
// check contents of data
}
看看那里有什么不对吗? (样本数量,尺寸等)